1use crate::{content_encoding::SupportedEncodings, BoxError};
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use http::HeaderValue;
7use http_body::{Body, Frame};
8use pin_project_lite::pin_project;
9use std::{
10 io,
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14use tokio::io::AsyncRead;
15use tokio_util::io::StreamReader;
16
/// Set of content-codings, each individually enabled or disabled.
///
/// Whether an enabled coding is actually reported as supported also depends on
/// crate features; see the `SupportedEncodings` implementation for this type.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AcceptEncoding {
    /// Whether `gzip` is enabled.
    pub(crate) gzip: bool,
    /// Whether `deflate` is enabled.
    pub(crate) deflate: bool,
    /// Whether `br` is enabled.
    pub(crate) br: bool,
    /// Whether `zstd` is enabled.
    pub(crate) zstd: bool,
}
24
25impl AcceptEncoding {
26 #[allow(dead_code)]
27 pub(crate) fn to_header_value(self) -> Option<HeaderValue> {
28 let accept = match (self.gzip(), self.deflate(), self.br(), self.zstd()) {
29 (true, true, true, false) => "gzip,deflate,br",
30 (true, true, false, false) => "gzip,deflate",
31 (true, false, true, false) => "gzip,br",
32 (true, false, false, false) => "gzip",
33 (false, true, true, false) => "deflate,br",
34 (false, true, false, false) => "deflate",
35 (false, false, true, false) => "br",
36 (true, true, true, true) => "zstd,gzip,deflate,br",
37 (true, true, false, true) => "zstd,gzip,deflate",
38 (true, false, true, true) => "zstd,gzip,br",
39 (true, false, false, true) => "zstd,gzip",
40 (false, true, true, true) => "zstd,deflate,br",
41 (false, true, false, true) => "zstd,deflate",
42 (false, false, true, true) => "zstd,br",
43 (false, false, false, true) => "zstd",
44 (false, false, false, false) => return None,
45 };
46 Some(HeaderValue::from_static(accept))
47 }
48
49 #[allow(dead_code)]
50 pub(crate) fn set_gzip(&mut self, enable: bool) {
51 self.gzip = enable;
52 }
53
54 #[allow(dead_code)]
55 pub(crate) fn set_deflate(&mut self, enable: bool) {
56 self.deflate = enable;
57 }
58
59 #[allow(dead_code)]
60 pub(crate) fn set_br(&mut self, enable: bool) {
61 self.br = enable;
62 }
63
64 #[allow(dead_code)]
65 pub(crate) fn set_zstd(&mut self, enable: bool) {
66 self.zstd = enable;
67 }
68}
69
70impl SupportedEncodings for AcceptEncoding {
71 #[allow(dead_code)]
72 fn gzip(&self) -> bool {
73 #[cfg(any(feature = "decompression-gzip", feature = "compression-gzip"))]
74 return self.gzip;
75
76 #[cfg(not(any(feature = "decompression-gzip", feature = "compression-gzip")))]
77 return false;
78 }
79
80 #[allow(dead_code)]
81 fn deflate(&self) -> bool {
82 #[cfg(any(feature = "decompression-deflate", feature = "compression-deflate"))]
83 return self.deflate;
84
85 #[cfg(not(any(feature = "decompression-deflate", feature = "compression-deflate")))]
86 return false;
87 }
88
89 #[allow(dead_code)]
90 fn br(&self) -> bool {
91 #[cfg(any(feature = "decompression-br", feature = "compression-br"))]
92 return self.br;
93
94 #[cfg(not(any(feature = "decompression-br", feature = "compression-br")))]
95 return false;
96 }
97
98 #[allow(dead_code)]
99 fn zstd(&self) -> bool {
100 #[cfg(any(feature = "decompression-zstd", feature = "compression-zstd"))]
101 return self.zstd;
102
103 #[cfg(not(any(feature = "decompression-zstd", feature = "compression-zstd")))]
104 return false;
105 }
106}
107
108impl Default for AcceptEncoding {
109 fn default() -> Self {
110 AcceptEncoding {
111 gzip: true,
112 deflate: true,
113 br: true,
114 zstd: true,
115 }
116 }
117}
118
/// `AsyncRead` view of a body `B`.
///
/// The body's data frames become a stream (`BodyIntoStream`). That stream's errors are
/// replaced by `io::Error`s (`StreamErrorIntoIoError`). The result is read through a
/// `StreamReader`.
pub(crate) type AsyncReadBody<B> =
    StreamReader<StreamErrorIntoIoError<BodyIntoStream<B>, <B as Body>::Error>, <B as Body>::Data>;
122
/// A transformation that wraps one `AsyncRead` in another.
///
/// NOTE(review): presumably implemented by compression encoders/decoders, since `apply`
/// receives a `CompressionLevel`. Confirm against the implementors.
pub(crate) trait DecorateAsyncRead {
    /// The reader being wrapped.
    type Input: AsyncRead;
    /// The reader produced by `apply`.
    type Output: AsyncRead;

    /// Wraps `input`, passing `quality` along to the decoration.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;

    /// Projects a pinned `Output` back to the pinned `Input` it wraps.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
}
136
pin_project! {
    /// A `Body` that was converted into an `AsyncRead`, decorated by a
    /// `DecorateAsyncRead` (`M`), and exposed as a `Body` again.
    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
        // The decorated reader. The original body stays reachable through it via
        // `M::get_pin_mut`.
        #[pin]
        pub read: M::Output,
        // Buffer the decorated reader is read into. Filled bytes are split off into
        // outgoing data frames.
        buf: BytesMut,
        // Set once `read` has reported end-of-file.
        read_all_data: bool,
    }
}
150
151impl<M: DecorateAsyncRead> WrapBody<M> {
152 const INTERNAL_BUF_CAPACITY: usize = 4096;
153}
154
155impl<M: DecorateAsyncRead> WrapBody<M> {
156 #[allow(dead_code)]
157 pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
158 where
159 B: Body,
160 M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
161 {
162 let stream = BodyIntoStream::new(body);
164
165 let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
168
169 let read = StreamReader::new(stream);
171
172 let read = M::apply(read, quality);
174
175 Self {
176 read,
177 buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
178 read_all_data: false,
179 }
180 }
181}
182
/// Emits the decorated reader's output as data frames, then forwards the original body's
/// remaining frames.
///
/// Errors raised by the original body are returned as-is (boxed). Any other I/O error from
/// the decorated reader is boxed and returned.
impl<B, M> Body for WrapBody<M>
where
    B: Body,
    B::Error: Into<BoxError>,
    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
{
    type Data = Bytes;
    type Error = BoxError;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        // Phase 1: drain the decorated reader into data frames.
        if !*this.read_all_data {
            // `split()` below hands filled bytes off together with their capacity;
            // re-reserve once nothing is left.
            if this.buf.capacity() == 0 {
                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
            }

            let result = tokio_util::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);

            match ready!(result) {
                // EOF from the reader: fall through to the remaining frames.
                Ok(0) => {
                    *this.read_all_data = true;
                }
                Ok(_) => {
                    let chunk = this.buf.split().freeze();
                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
                }
                Err(err) => {
                    // When the original body fails, `StreamErrorIntoIoError` stores its
                    // error and yields a sentinel `io::Error` instead. Recover the real
                    // error from there.
                    let body_error: Option<B::Error> = M::get_pin_mut(this.read)
                        .get_pin_mut()
                        .project()
                        .error
                        .take();

                    if let Some(body_error) = body_error {
                        return Poll::Ready(Some(Err(body_error.into())));
                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
                        // The sentinel is only produced right after an error is stored,
                        // so seeing it with nothing stored is an invariant violation.
                        unreachable!()
                    } else {
                        return Poll::Ready(Some(Err(err.into())));
                    }
                }
            }
        }

        // Phase 2: the data is exhausted. Poll the original body (through the adapters)
        // for any remaining frames such as trailers, converting their data to `Bytes`.
        let body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
        body.poll_frame(cx).map(|option| {
            option.map(|result| {
                result
                    .map(|frame| frame.map_data(|mut data| data.copy_to_bytes(data.remaining())))
                    .map_err(|err| err.into())
            })
        })
    }
}
244
245pin_project! {
246 pub(crate) struct BodyIntoStream<B>
247 where
248 B: Body,
249 {
250 #[pin]
251 body: B,
252 yielded_all_data: bool,
253 non_data_frame: Option<Frame<B::Data>>,
254 }
255}
256
257#[allow(dead_code)]
258impl<B> BodyIntoStream<B>
259where
260 B: Body,
261{
262 pub(crate) fn new(body: B) -> Self {
263 Self {
264 body,
265 yielded_all_data: false,
266 non_data_frame: None,
267 }
268 }
269
270 pub(crate) fn get_ref(&self) -> &B {
272 &self.body
273 }
274
275 pub(crate) fn get_mut(&mut self) -> &mut B {
277 &mut self.body
278 }
279
280 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
282 self.project().body
283 }
284
285 pub(crate) fn into_inner(self) -> B {
287 self.body
288 }
289}
290
291impl<B> Stream for BodyIntoStream<B>
292where
293 B: Body,
294{
295 type Item = Result<B::Data, B::Error>;
296
297 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
298 loop {
299 let this = self.as_mut().project();
300
301 if *this.yielded_all_data {
302 return Poll::Ready(None);
303 }
304
305 match std::task::ready!(this.body.poll_frame(cx)) {
306 Some(Ok(frame)) => match frame.into_data() {
307 Ok(data) => return Poll::Ready(Some(Ok(data))),
308 Err(frame) => {
309 *this.yielded_all_data = true;
310 *this.non_data_frame = Some(frame);
311 }
312 },
313 Some(Err(err)) => return Poll::Ready(Some(Err(err))),
314 None => {
315 *this.yielded_all_data = true;
316 }
317 }
318 }
319 }
320}
321
322impl<B> Body for BodyIntoStream<B>
323where
324 B: Body,
325{
326 type Data = B::Data;
327 type Error = B::Error;
328
329 fn poll_frame(
330 mut self: Pin<&mut Self>,
331 cx: &mut Context<'_>,
332 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
333 if let Some(frame) = std::task::ready!(self.as_mut().poll_next(cx)) {
336 return Poll::Ready(Some(frame.map(Frame::data)));
337 }
338
339 let this = self.project();
340
341 if let Some(frame) = this.non_data_frame.take() {
343 return Poll::Ready(Some(Ok(frame)));
344 }
345
346 this.body.poll_frame(cx)
349 }
350
351 #[inline]
352 fn size_hint(&self) -> http_body::SizeHint {
353 self.body.size_hint()
354 }
355}
356
pin_project! {
    /// Stream adapter that replaces each error item with a sentinel `io::Error`.
    ///
    /// The sentinel carries raw OS code `SENTINEL_ERROR_CODE`. The original error is kept
    /// in `error` so a consumer can recover it later.
    pub(crate) struct StreamErrorIntoIoError<S, E> {
        #[pin]
        inner: S,
        // Most recent error yielded by `inner`, waiting to be taken by the consumer.
        error: Option<E>,
    }
}
364
365impl<S, E> StreamErrorIntoIoError<S, E> {
366 pub(crate) fn new(inner: S) -> Self {
367 Self { inner, error: None }
368 }
369
370 pub(crate) fn get_ref(&self) -> &S {
372 &self.inner
373 }
374
375 pub(crate) fn get_mut(&mut self) -> &mut S {
377 &mut self.inner
378 }
379
380 pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
382 self.project().inner
383 }
384
385 pub(crate) fn into_inner(self) -> S {
387 self.inner
388 }
389}
390
391impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
392where
393 S: Stream<Item = Result<T, E>>,
394{
395 type Item = Result<T, io::Error>;
396
397 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 let this = self.project();
399 match ready!(this.inner.poll_next(cx)) {
400 None => Poll::Ready(None),
401 Some(Ok(value)) => Poll::Ready(Some(Ok(value))),
402 Some(Err(err)) => {
403 *this.error = Some(err);
404 Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
405 }
406 }
407 }
408}
409
/// Raw OS error code that `StreamErrorIntoIoError` puts on the `io::Error`s it yields.
///
/// An `io::Error` with this code means the real cause is stored in the adapter's `error`
/// field, not in the `io::Error` itself.
///
/// NOTE(review): the value is presumably chosen so it does not collide with real OS
/// error codes.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418;
411
/// Level of compression data should be compressed with.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
pub enum CompressionLevel {
    /// The backend's fastest compression level.
    Fastest,
    /// The backend's best compression level.
    Best,
    /// The backend's default compression level. This is also `CompressionLevel::default()`.
    #[default]
    Default,
    /// An explicit numeric level, passed unchanged to the compression backend.
    ///
    /// How the number is interpreted is up to the backend.
    Precise(i32),
}
433
434#[cfg(any(
435 feature = "compression-br",
436 feature = "compression-gzip",
437 feature = "compression-deflate",
438 feature = "compression-zstd"
439))]
440use async_compression::Level as AsyncCompressionLevel;
441
#[cfg(any(
    feature = "compression-br",
    feature = "compression-gzip",
    feature = "compression-deflate",
    feature = "compression-zstd"
))]
impl CompressionLevel {
    /// Converts this level into the corresponding `async_compression` level.
    ///
    /// Each variant maps to its same-named counterpart. `Precise` values are forwarded
    /// unchanged.
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Precise(quality) => AsyncCompressionLevel::Precise(quality),
            Self::Default => AsyncCompressionLevel::Default,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Fastest => AsyncCompressionLevel::Fastest,
        }
    }
}