snix_glue/fetchers/
decompression.rs1use std::{
2 io, mem,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder};
8use futures::ready;
9use pin_project::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, BufReader, ReadBuf};
11
/// Leading bytes of a gzip stream.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Leading bytes of a bzip2 stream.
const BZIP2_MAGIC: [u8; 3] = *b"BZh";
/// Leading bytes of an xz stream (`\xFD7zXZ\0`).
const XZ_MAGIC: [u8; 6] = [0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00];

/// `Ord::max` is not a `const fn`, so this small helper is used to compute
/// `BYTES_NEEDED` at compile time.
const fn max_usize(a: usize, b: usize) -> usize {
    if a > b {
        a
    } else {
        b
    }
}

/// Number of leading bytes that must be read before every supported format
/// can be identified: the length of the longest signature. Derived rather
/// than hard-coded so it cannot drift if a signature is added or changed.
const BYTES_NEEDED: usize = max_usize(
    GZIP_MAGIC.len(),
    max_usize(BZIP2_MAGIC.len(), XZ_MAGIC.len()),
);

/// Compression formats that can be detected from a stream's leading bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Algorithm {
    Gzip,
    Bzip2,
    Xz,
}

impl Algorithm {
    /// Identifies the compression format from the leading bytes of a stream.
    ///
    /// Returns `None` if `magic` does not start with any known signature,
    /// which includes the case where `magic` is shorter than the signature
    /// it would otherwise match.
    fn from_magic(magic: &[u8]) -> Option<Self> {
        if magic.starts_with(&GZIP_MAGIC) {
            Some(Self::Gzip)
        } else if magic.starts_with(&BZIP2_MAGIC) {
            Some(Self::Bzip2)
        } else if magic.starts_with(&XZ_MAGIC) {
            Some(Self::Xz)
        } else {
            None
        }
    }
}
36}
37
38#[pin_project]
39struct WithPreexistingBuffer<R> {
40 buffer: Vec<u8>,
41 #[pin]
42 inner: R,
43}
44
45impl<R> AsyncRead for WithPreexistingBuffer<R>
46where
47 R: AsyncRead,
48{
49 fn poll_read(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 buf: &mut ReadBuf<'_>,
53 ) -> Poll<io::Result<()>> {
54 let this = self.project();
55 if !this.buffer.is_empty() {
56 buf.put_slice(this.buffer);
58 this.buffer.clear();
59 }
60 this.inner.poll_read(cx, buf)
61 }
62}
63
64#[pin_project(project = DecompressedReaderInnerProj)]
65enum DecompressedReaderInner<R> {
66 Unknown {
67 buffer: Vec<u8>,
68 #[pin]
69 inner: Option<R>,
70 },
71 Gzip(#[pin] GzipDecoder<BufReader<WithPreexistingBuffer<R>>>),
72 Bzip2(#[pin] BzDecoder<BufReader<WithPreexistingBuffer<R>>>),
73 Xz(#[pin] XzDecoder<BufReader<WithPreexistingBuffer<R>>>),
74}
75
76impl<R> DecompressedReaderInner<R>
77where
78 R: AsyncBufRead,
79{
80 fn switch_to(&mut self, algorithm: Algorithm) {
81 let (buffer, inner) = match self {
82 DecompressedReaderInner::Unknown { buffer, inner } => {
83 (mem::take(buffer), inner.take().unwrap())
84 }
85 DecompressedReaderInner::Gzip(_)
86 | DecompressedReaderInner::Bzip2(_)
87 | DecompressedReaderInner::Xz(_) => unreachable!(),
88 };
89 let inner = BufReader::new(WithPreexistingBuffer { buffer, inner });
90
91 *self = match algorithm {
92 Algorithm::Gzip => Self::Gzip(GzipDecoder::new(inner)),
93 Algorithm::Bzip2 => Self::Bzip2(BzDecoder::new(inner)),
94 Algorithm::Xz => Self::Xz(XzDecoder::new(inner)),
95 }
96 }
97}
98
99impl<R> AsyncRead for DecompressedReaderInner<R>
100where
101 R: AsyncBufRead,
102{
103 fn poll_read(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 buf: &mut ReadBuf<'_>,
107 ) -> Poll<io::Result<()>> {
108 match self.project() {
109 DecompressedReaderInnerProj::Unknown { .. } => {
110 unreachable!("Can't call poll_read on Unknown")
111 }
112 DecompressedReaderInnerProj::Gzip(inner) => inner.poll_read(cx, buf),
113 DecompressedReaderInnerProj::Bzip2(inner) => inner.poll_read(cx, buf),
114 DecompressedReaderInnerProj::Xz(inner) => inner.poll_read(cx, buf),
115 }
116 }
117}
118
119#[pin_project]
120pub struct DecompressedReader<R> {
121 #[pin]
122 inner: DecompressedReaderInner<R>,
123 switch_to: Option<Algorithm>,
124}
125
126impl<R> DecompressedReader<R> {
127 pub fn new(inner: R) -> Self {
128 Self {
129 inner: DecompressedReaderInner::Unknown {
130 buffer: vec![0; BYTES_NEEDED],
131 inner: Some(inner),
132 },
133 switch_to: None,
134 }
135 }
136}
137
138impl<R> AsyncRead for DecompressedReader<R>
139where
140 R: AsyncBufRead + Unpin,
141{
142 fn poll_read(
143 self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &mut ReadBuf<'_>,
146 ) -> Poll<io::Result<()>> {
147 let mut this = self.project();
148 let (buffer, inner) = match this.inner.as_mut().project() {
149 DecompressedReaderInnerProj::Gzip(inner) => return inner.poll_read(cx, buf),
150 DecompressedReaderInnerProj::Bzip2(inner) => return inner.poll_read(cx, buf),
151 DecompressedReaderInnerProj::Xz(inner) => return inner.poll_read(cx, buf),
152 DecompressedReaderInnerProj::Unknown { buffer, inner } => (buffer, inner),
153 };
154
155 let mut our_buf = ReadBuf::new(buffer);
156 ready!(inner.as_pin_mut().unwrap().poll_read(cx, &mut our_buf))?;
157
158 let data = our_buf.filled();
159 if data.len() >= BYTES_NEEDED {
160 if let Some(algorithm) = Algorithm::from_magic(data) {
161 this.inner.as_mut().switch_to(algorithm);
162 } else {
163 return Poll::Ready(Err(io::Error::new(
164 io::ErrorKind::InvalidData,
165 "tar data not gz, bzip2, or xz compressed",
166 )));
167 }
168 this.inner.poll_read(cx, buf)
169 } else {
170 cx.waker().wake_by_ref();
171 Poll::Pending
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::{io::ErrorKind, path::Path};

    use async_compression::tokio::bufread::{BzEncoder, GzipEncoder, XzEncoder};
    use futures::TryStreamExt;
    use rstest::rstest;
    use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
    use tokio_tar::Archive;

    use super::*;

    const PLAINTEXT: &[u8] = b"abcdefghijk";

    /// Drains `encoder` (which compresses `PLAINTEXT`), feeds the result
    /// through a `DecompressedReader` and asserts that `PLAINTEXT` comes back.
    async fn assert_round_trip<E>(mut encoder: E)
    where
        E: AsyncRead + Unpin,
    {
        let mut compressed = vec![];
        encoder.read_to_end(&mut compressed).await.unwrap();

        let mut reader = DecompressedReader::new(BufReader::new(&compressed[..]));
        let mut round_tripped = vec![];
        reader.read_to_end(&mut round_tripped).await.unwrap();

        assert_eq!(PLAINTEXT, &round_tripped[..]);
    }

    #[tokio::test]
    async fn gzip() {
        assert_round_trip(GzipEncoder::new(PLAINTEXT)).await;
    }

    #[tokio::test]
    async fn bzip2() {
        assert_round_trip(BzEncoder::new(PLAINTEXT)).await;
    }

    #[tokio::test]
    async fn xz() {
        assert_round_trip(XzEncoder::new(PLAINTEXT)).await;
    }

    #[tokio::test]
    async fn rejects_uncompressed_input() {
        let input = b"plain, uncompressed text";
        let mut reader = DecompressedReader::new(BufReader::new(&input[..]));
        let mut out = vec![];
        let err = reader.read_to_end(&mut out).await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[rstest]
    #[case::gzip(include_bytes!("../tests/blob.tar.gz"))]
    #[case::bzip2(include_bytes!("../tests/blob.tar.bz2"))]
    #[case::xz(include_bytes!("../tests/blob.tar.xz"))]
    #[tokio::test]
    async fn compressed_tar(#[case] data: &[u8]) {
        let reader = DecompressedReader::new(BufReader::new(data));
        let mut archive = Archive::new(reader);
        let mut entries: Vec<_> = archive.entries().unwrap().try_collect().await.unwrap();

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].path().unwrap().as_ref(), Path::new("empty"));
        let mut contents = String::new();
        entries[0].read_to_string(&mut contents).await.unwrap();
        assert_eq!(contents, "");
    }
}