1pub mod predicate;
73
74mod body;
75mod future;
76mod layer;
77mod pin_project_cfg;
78mod service;
79
80#[doc(inline)]
81pub use self::{
82 body::CompressionBody,
83 future::ResponseFuture,
84 layer::CompressionLayer,
85 predicate::{DefaultPredicate, Predicate},
86 service::Compression,
87};
88pub use crate::compression_utils::CompressionLevel;
89
#[cfg(test)]
mod tests {
    use crate::compression::predicate::SizeAbove;

    use super::*;
    use crate::test_helpers::{Body, WithTrailers};
    use async_compression::tokio::write::{BrotliDecoder, BrotliEncoder};
    use flate2::read::GzDecoder;
    use http::header::{
        ACCEPT_ENCODING, ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_RANGE, CONTENT_TYPE, RANGE,
    };
    use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
    use http_body::Body as _;
    use http_body_util::BodyExt;
    use std::convert::Infallible;
    use std::io::Read;
    use std::sync::{Arc, RwLock};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio_util::io::StreamReader;
    use tower::{service_fn, Service, ServiceExt};

    /// Predicate that returns `true` for every response, so tests are not affected by
    /// the default size / content-type rules.
    #[derive(Clone)]
    struct Always;

    impl Predicate for Always {
        fn should_compress<B>(&self, _: &http::Response<B>) -> bool
        where
            B: http_body::Body,
        {
            true
        }
    }

    /// Decodes a gzip stream into a UTF-8 string.
    ///
    /// # Panics
    ///
    /// Panics if `compressed` is not valid gzip or does not decode to UTF-8.
    fn gunzip_to_string(compressed: &[u8]) -> String {
        let mut decompressed = String::new();
        GzDecoder::new(compressed)
            .read_to_string(&mut decompressed)
            .expect("couldn't gzip-decode");
        decompressed
    }

    /// Decodes a brotli stream and returns the raw bytes.
    ///
    /// # Panics
    ///
    /// Panics if `compressed` is not valid brotli or ends before the stream is complete.
    async fn brotli_decode(compressed: &[u8]) -> Vec<u8> {
        let mut output = Vec::new();
        {
            let mut decoder = BrotliDecoder::new(&mut output);
            decoder
                .write_all(compressed)
                .await
                .expect("couldn't brotli-decode");
            // Unlike `flush`, `shutdown` reports an error when the input stopped
            // mid-stream, so truncated bodies are caught here.
            decoder
                .shutdown()
                .await
                .expect("couldn't finish brotli stream");
        }
        output
    }

    #[tokio::test]
    async fn gzip_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");

        // Collect the whole body so both the data and the trailers are available.
        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn x_gzip_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header("accept-encoding", "x-gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        // A request for `x-gzip` must be answered with exactly one `gzip` header value.
        assert_eq!(
            res.headers()
                .get_all("content-encoding")
                .iter()
                .collect::<Vec<&HeaderValue>>(),
            vec![HeaderValue::from_static("gzip")]
        );

        let collected = res.into_body().collect().await.unwrap();
        let trailers = collected.trailers().cloned().unwrap();
        let compressed_data = collected.to_bytes();

        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");
        assert_eq!(trailers["foo"], "bar");
    }

    #[tokio::test]
    async fn zstd_works() {
        let svc = service_fn(handle);
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header("accept-encoding", "zstd")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "zstd");

        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        let decompressed = zstd::stream::decode_all(std::io::Cursor::new(compressed_data)).unwrap();
        let decompressed = String::from_utf8(decompressed).unwrap();

        assert_eq!(decompressed, "Hello, World!");
    }

    #[tokio::test]
    async fn no_recompress() {
        const DATA: &str = "Hello, World! I'm already compressed with br!";

        let svc = service_fn(|_| async {
            let mut buf = Vec::new();
            {
                let mut enc = BrotliEncoder::new(&mut buf);
                enc.write_all(DATA.as_bytes()).await?;
                // `shutdown` finalizes the brotli stream; `flush` alone would leave it
                // without its terminating block.
                enc.shutdown().await?;
            }

            let resp = Response::builder()
                .header("content-encoding", "br")
                .body(Body::from(buf))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });
        let mut svc = Compression::new(svc);

        // The client prefers gzip, but the response is already brotli-encoded and must
        // pass through untouched.
        let req = Request::builder()
            .header("accept-encoding", "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();

        assert_eq!(
            res.headers()
                .get("content-encoding")
                .and_then(|h| h.to_str().ok())
                .unwrap_or_default(),
            "br",
        );

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    /// Test handler that responds with `"Hello, World!"` and a `foo: bar` trailer.
    async fn handle(_req: Request<Body>) -> Result<Response<WithTrailers<Body>>, Infallible> {
        let mut trailers = HeaderMap::new();
        trailers.insert(HeaderName::from_static("foo"), "bar".parse().unwrap());
        let body = Body::from("Hello, World!").with_trailers(trailers);
        Ok(Response::builder().body(body).unwrap())
    }

    #[tokio::test]
    async fn will_not_compress_if_filtered_out() {
        const DATA: &str = "Hello world uncompressed";

        let svc_fn = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        // Predicate that declines the first response it sees and then alternates,
        // counting calls in shared state so clones stay in sync.
        #[derive(Default, Clone)]
        struct EveryOtherResponse(Arc<RwLock<u64>>);

        impl Predicate for EveryOtherResponse {
            fn should_compress<B>(&self, _: &http::Response<B>) -> bool
            where
                B: http_body::Body,
            {
                let mut calls = self.0.write().unwrap();
                let should_compress = *calls % 2 != 0;
                *calls += 1;
                should_compress
            }
        }

        let mut svc = Compression::new(svc_fn).compress_when(EveryOtherResponse::default());

        // First response: the predicate declines, so the body is passed through as-is.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(data, DATA);

        // Second response: the predicate accepts, so the body must be brotli-encoded
        // and must round-trip back to the original data.
        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");

        let data = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(brotli_decode(&data).await, DATA.as_bytes());
    }

    #[tokio::test]
    async fn doesnt_compress_images() {
        // Body is large enough to pass the default size threshold, so only the
        // content type can prevent compression.
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/png".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert!(res.headers().get(CONTENT_ENCODING).is_none());
    }

    #[tokio::test]
    async fn does_compress_svg() {
        // SVG is an `image/*` type but text-based, so it should still be compressed.
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            let mut res = Response::new(Body::from(
                "a".repeat((SizeAbove::DEFAULT_MIN_SIZE * 2) as usize),
            ));
            res.headers_mut()
                .insert(CONTENT_TYPE, "image/svg+xml".parse().unwrap());
            Ok(res)
        }

        let svc = Compression::new(service_fn(handle));

        let res = svc
            .oneshot(
                Request::builder()
                    .header(ACCEPT_ENCODING, "gzip")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "gzip");
    }

    #[tokio::test]
    async fn compress_with_quality() {
        const DATA: &str = "Check compression quality level! Check compression quality level! Check compression quality level!";
        let level = CompressionLevel::Best;

        let svc = service_fn(|_| async {
            let resp = Response::builder()
                .body(Body::from(DATA.as_bytes()))
                .unwrap();
            Ok::<_, std::io::Error>(resp)
        });

        let mut svc = Compression::new(svc).quality(level);

        let req = Request::builder()
            .header("accept-encoding", "br")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        assert_eq!(res.headers()[CONTENT_ENCODING], "br");

        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        // Compress the same input independently at the same level; the middleware's
        // output must be byte-identical if it honoured the configured quality.
        let compressed_with_level = {
            use async_compression::tokio::bufread::BrotliEncoder;

            let stream = Box::pin(futures_util::stream::once(async move {
                Ok::<_, std::io::Error>(DATA.as_bytes())
            }));
            let reader = StreamReader::new(stream);
            let mut enc = BrotliEncoder::with_quality(reader, level.into_async_compression());

            let mut buf = Vec::new();
            enc.read_to_end(&mut buf).await.unwrap();
            buf
        };

        assert_eq!(
            compressed_data,
            compressed_with_level.as_slice(),
            "Compression level is not respected"
        );
    }

    #[tokio::test]
    async fn should_not_compress_ranges() {
        let svc = service_fn(|_| async {
            let mut res = Response::new(Body::from("Hello"));
            let headers = res.headers_mut();
            headers.insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            headers.insert(CONTENT_RANGE, "bytes 0-4/*".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .header(RANGE, "bytes=0-4")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let headers = res.headers().clone();

        let collected = res.into_body().collect().await.unwrap().to_bytes();

        // A partial (ranged) response must be left as-is, even with an `Always` predicate.
        assert_eq!(headers[ACCEPT_RANGES], "bytes");
        assert!(!headers.contains_key(CONTENT_ENCODING));
        assert_eq!(collected, "Hello");
    }

    #[tokio::test]
    async fn should_strip_accept_ranges_header_when_compressing() {
        let svc = service_fn(|_| async {
            let mut res = Response::new(Body::from("Hello, World!"));
            res.headers_mut()
                .insert(ACCEPT_RANGES, "bytes".parse().unwrap());
            Ok::<_, std::io::Error>(res)
        });
        let mut svc = Compression::new(svc).compress_when(Always);

        let req = Request::builder()
            .header(ACCEPT_ENCODING, "gzip")
            .body(Body::empty())
            .unwrap();
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let headers = res.headers().clone();

        let compressed_data = res.into_body().collect().await.unwrap().to_bytes();

        // Byte ranges of the original body are meaningless once it is compressed.
        assert!(!headers.contains_key(ACCEPT_RANGES));
        assert_eq!(headers[CONTENT_ENCODING], "gzip");
        assert_eq!(gunzip_to_string(&compressed_data), "Hello, World!");
    }

    #[tokio::test]
    async fn size_hint_identity() {
        let msg = "Hello, world!";
        let svc = service_fn(|_| async { Ok::<_, std::io::Error>(Response::new(Body::from(msg))) });
        let mut svc = Compression::new(svc);

        // No `accept-encoding` header, so the body stays identity-encoded and keeps an
        // exact size hint.
        let req = Request::new(Body::empty());
        let res = svc.ready().await.unwrap().call(req).await.unwrap();
        let body = res.into_body();
        assert_eq!(body.size_hint().exact().unwrap(), msg.len() as u64);
    }
}