tower_http/compression/
mod.rs

1//! Middleware that compresses response bodies.
2//!
3//! # Example
4//!
5//! Example showing how to respond with the compressed contents of a file.
6//!
7//! ```rust
8//! use bytes::{Bytes, BytesMut};
9//! use http::{Request, Response, header::ACCEPT_ENCODING};
10//! use http_body_util::{Full, BodyExt, StreamBody, combinators::UnsyncBoxBody};
11//! use http_body::Frame;
12//! use std::convert::Infallible;
13//! use tokio::fs::{self, File};
14//! use tokio_util::io::ReaderStream;
15//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn};
16//! use tower_http::{compression::CompressionLayer, BoxError};
17//! use futures_util::TryStreamExt;
18//!
19//! type BoxBody = UnsyncBoxBody<Bytes, std::io::Error>;
20//!
21//! # #[tokio::main]
22//! # async fn main() -> Result<(), BoxError> {
23//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<BoxBody>, Infallible> {
24//!     // Open the file.
25//!     let file = File::open("Cargo.toml").await.expect("file missing");
26//!     // Convert the file into a `Stream` of `Bytes`.
27//!     let stream = ReaderStream::new(file);
28//!     // Convert the stream into a stream of data `Frame`s.
29//!     let stream = stream.map_ok(Frame::data);
30//!     // Convert the `Stream` into a `Body`.
31//!     let body = StreamBody::new(stream);
//!     // Erase the type because it's very hard to name in the function signature.
33//!     let body = body.boxed_unsync();
34//!     // Create response.
35//!     Ok(Response::new(body))
36//! }
37//!
38//! let mut service = ServiceBuilder::new()
39//!     // Compress responses based on the `Accept-Encoding` header.
40//!     .layer(CompressionLayer::new())
41//!     .service_fn(handle);
42//!
43//! // Call the service.
44//! let request = Request::builder()
45//!     .header(ACCEPT_ENCODING, "gzip")
46//!     .body(Full::<Bytes>::default())?;
47//!
48//! let response = service
49//!     .ready()
50//!     .await?
51//!     .call(request)
52//!     .await?;
53//!
54//! assert_eq!(response.headers()["content-encoding"], "gzip");
55//!
56//! // Read the body
57//! let bytes = response
58//!     .into_body()
59//!     .collect()
60//!     .await?
61//!     .to_bytes();
62//!
63//! // The compressed body should be smaller 🤞
64//! let uncompressed_len = fs::read_to_string("Cargo.toml").await?.len();
65//! assert!(bytes.len() < uncompressed_len);
66//! #
67//! # Ok(())
68//! # }
69//! ```
70//!
71
72pub mod predicate;
73
74mod body;
75mod future;
76mod layer;
77mod pin_project_cfg;
78mod service;
79
80#[doc(inline)]
81pub use self::{
82    body::CompressionBody,
83    future::ResponseFuture,
84    layer::CompressionLayer,
85    predicate::{DefaultPredicate, Predicate},
86    service::Compression,
87};
88pub use crate::compression_utils::CompressionLevel;
89
#[cfg(test)]
mod tests {
    use crate::compression::predicate::SizeAbove;

    use super::*;
    use crate::test_helpers::{Body, WithTrailers};
    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use http::header::{
        ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
    };
    use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
    use http_body::Body as _;
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;
    use tower::{service_fn, Service, ServiceExt};

    /// Predicate that approves compression for every response.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            true
        }
    }

    /// Test service: responds with the body `"Hello, World!"` followed by a `foo: bar`
    /// trailer, regardless of the request.
    async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
        let mut trailers = HeaderMap::new();
        trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
        let body = Body::from("Hello, World!").with_trailers(trailers);
        Ok(Response::builder().body(body).unwrap())
    }

    /// Decodes a gzip-encoded body into a UTF-8 string.
    ///
    /// Uses the blocking `flate2` decoder because it is much simpler than
    /// `async-compression`, and blocking during tests is fine.
    fn gunzip_to_string(compressed: &[u8]) -> String {
        let mut decoder = GzDecoder::new(compressed);
        let mut decompressed = String::new();
        decoder
            .read_to_string(&mut decompressed)
            .expect("body should be gzip-encoded UTF-8");
        decompressed
    }

    /// Decodes a brotli stream and returns the decompressed bytes.
    ///
    /// The decoder is finished with `shutdown` rather than only flushed; async-compression
    /// is expected to reject an incomplete stream at that point.
    async fn brotli_decode(compressed: &[u8]) -> Vec<u8> {
        let mut output = Vec::new();
        let mut decoder = BrotliDecoder::new(&mut output);
        decoder
            .write_all(compressed)
            .await
            .expect("couldn't brotli-decode");
        decoder
            .shutdown()
            .await
            .expect("couldn't finish brotli stream");
        output
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        // read the compressed body, keeping the trailers around
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");

        // trailers are maintained
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn x_gzip_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "x-gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        // we treat x-gzip as equivalent to gzip and don't have to return x-gzip
        // taking extra caution by checking all headers with this name
        assert_eq!(
            res.headers()
                .get_all("content-encoding")
                .iter()
                .collect::<Vec<&HeaderValue>>(),
            vec!(HeaderValue::from_static("gzip"))
        );

        // read the compressed body, keeping the trailers around
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");

        // trailers are maintained
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        // read the compressed body
        let body = res.into_body();
        let compressed_data = body.collect().await.unwrap().to_bytes();

        // decompress the body
        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        let svc = service_fn(|_| async {
            let buf = {
                let mut buf = Vec::new();

                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                // `shutdown` (not just `flush`) finishes the encoder so the body is a
                // complete brotli stream.
                enc.shutdown().await?;
                buf
            };

            let resp = Response::builder()
                .header("content-encoding", "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let mut svc = Compression::new(svc);

        // call the service
        //
        // note: the accept-encoding doesn't match the content-encoding above, so that
        // we're able to see if the compression layer triggered or not
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        // check we didn't recompress
        assert_eq!(
            res.headers()
                .get("content-encoding")
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        // the body must still be the original brotli stream
        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        // Predicate that rejects the first response, approves the second, and keeps
        // alternating from there.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<AtomicUsize>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &http::Response<B>) -> bool
            where
                B: http_body::Body,
            {
                // `fetch_add` yields the count *before* this call, so calls 0, 2, 4, ...
                // skip compression and calls 1, 3, 5, ... compress.
                self.0.fetch_add(1, Ordering::SeqCst) % 2 != 0
            }
        }

        let svc_fn = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // First response: the predicate rejects compression.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert!(!res.headers().contains_key(CONTENT_ENCODING));

        let data = res.into_body().collect().await.unwrap().to_bytes();
        let still_uncompressed = String::from_utf8(data.to_vec()).unwrap();
        assert_eq!(DATA, &still_uncompressed);

        // Second response: the predicate approves compression.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_ne!(data, DATA.as_bytes());
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/png".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        let mut svc = Compression::new(svc).quality(level);

        // call the service
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        // read the compressed body
        let body = res.into_body();
        let compressed_data = body.collect().await.unwrap().to_bytes();

        // build the compressed body with the same quality level
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures_util::stream::once(async move {
                Ok::<_, std::io::Error>(DATA.as_bytes())
            }));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data,
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }

    #[tokio::test]
    async fn should_not_compress_ranges() {
        let svc = service_fn(|_| async {
            let mut res = Response::new(Body::from("Hello"));
            let headers = res.headers_mut();
            headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let mut svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .header(RANGE, "bytes=0-4")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let headers = res.headers().clone();

        // read the uncompressed body
        let collected = res.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(headers[ACCEPT_RANGES], "bytes");
        assert!(!headers.contains_key(CONTENT_ENCODING));
        assert_eq!(collected, "Hello");
    }

    #[tokio::test]
    async fn should_strip_accept_ranges_header_when_compressing() {
        let svc = service_fn(|_| async {
            let mut res = Response::new(Body::from("Hello, World!"));
            res.headers_mut()
                .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let mut svc = Compression::new(svc).compress_when(Always);

        // call the service
        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let headers = res.headers().clone();

        // read the compressed body
        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        assert!(!headers.contains_key(ACCEPT_RANGES));
        assert_eq!(headers[CONTENT_ENCODING], "gzip");
        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");
    }

    #[tokio::test]
    async fn size_hint_identity() {
        let msg = "Hello, world!";
        let svc = service_fn(|_| async { Ok::<_, std::io::Error>(Response::new(Body::from(msg))) });
        let mut svc = Compression::new(svc);

        let req = Request::new(Body::empty());
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let body = res.into_body();
        assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64);
    }
}
509        assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64);
510    }
511}