snix_glue/fetchers/decompression.rs

1use std::{
2    io, mem,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder};
8use futures::ready;
9use pin_project::pin_project;
10use tokio::io::{AsyncBufRead, AsyncRead, BufReader, ReadBuf};
11
/// Gzip member header: the ID1/ID2 bytes `1f 8b`.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// bzip2 stream header: the ASCII bytes `BZh`.
const BZIP2_MAGIC: [u8; 3] = [b'B', b'Z', b'h'];
/// xz stream header: `fd`, ASCII `7zXZ`, then a NUL byte.
const XZ_MAGIC: [u8; 6] = [0xfd, b'7', b'z', b'X', b'Z', 0x00];
/// How many leading bytes must be inspected to tell the formats apart.
/// The xz signature is the longest one, so it sets this length.
const BYTES_NEEDED: usize = XZ_MAGIC.len();

/// Compression formats that can be detected from a stream's leading bytes.
#[derive(Debug, Clone, Copy)]
enum Algorithm {
    Gzip,
    Bzip2,
    Xz,
}

impl Algorithm {
    /// Identifies the compression format from the leading bytes of a stream.
    ///
    /// Returns `None` if `magic` begins with none of the known signatures,
    /// which includes the case where it is too short to contain one.
    fn from_magic(magic: &[u8]) -> Option<Self> {
        // The signatures all start with different bytes, so at most one can
        // match and the probing order does not matter.
        let signatures: [(&[u8], Self); 3] = [
            (&GZIP_MAGIC[..], Self::Gzip),
            (&BZIP2_MAGIC[..], Self::Bzip2),
            (&XZ_MAGIC[..], Self::Xz),
        ];
        signatures
            .iter()
            .find(|(signature, _)| magic.starts_with(signature))
            .map(|&(_, algorithm)| algorithm)
    }
}
37
38#[pin_project]
39struct WithPreexistingBuffer<R> {
40    buffer: Vec<u8>,
41    #[pin]
42    inner: R,
43}
44
45impl<R> AsyncRead for WithPreexistingBuffer<R>
46where
47    R: AsyncRead,
48{
49    fn poll_read(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        buf: &mut ReadBuf<'_>,
53    ) -> Poll<io::Result<()>> {
54        let this = self.project();
55        if !this.buffer.is_empty() {
56            // TODO: check if the buffer fits first
57            buf.put_slice(this.buffer);
58            this.buffer.clear();
59        }
60        this.inner.poll_read(cx, buf)
61    }
62}
63
64#[pin_project(project = DecompressedReaderInnerProj)]
65enum DecompressedReaderInner<R> {
66    Unknown {
67        buffer: Vec<u8>,
68        #[pin]
69        inner: Option<R>,
70    },
71    Gzip(#[pin] GzipDecoder<BufReader<WithPreexistingBuffer<R>>>),
72    Bzip2(#[pin] BzDecoder<BufReader<WithPreexistingBuffer<R>>>),
73    Xz(#[pin] XzDecoder<BufReader<WithPreexistingBuffer<R>>>),
74}
75
76impl<R> DecompressedReaderInner<R>
77where
78    R: AsyncBufRead,
79{
80    fn switch_to(&mut self, algorithm: Algorithm) {
81        let (buffer, inner) = match self {
82            DecompressedReaderInner::Unknown { buffer, inner } => {
83                (mem::take(buffer), inner.take().unwrap())
84            }
85            DecompressedReaderInner::Gzip(_)
86            | DecompressedReaderInner::Bzip2(_)
87            | DecompressedReaderInner::Xz(_) => unreachable!(),
88        };
89        let inner = BufReader::new(WithPreexistingBuffer { buffer, inner });
90
91        *self = match algorithm {
92            Algorithm::Gzip => Self::Gzip(GzipDecoder::new(inner)),
93            Algorithm::Bzip2 => Self::Bzip2(BzDecoder::new(inner)),
94            Algorithm::Xz => Self::Xz(XzDecoder::new(inner)),
95        }
96    }
97}
98
99impl<R> AsyncRead for DecompressedReaderInner<R>
100where
101    R: AsyncBufRead,
102{
103    fn poll_read(
104        self: Pin<&mut Self>,
105        cx: &mut Context<'_>,
106        buf: &mut ReadBuf<'_>,
107    ) -> Poll<io::Result<()>> {
108        match self.project() {
109            DecompressedReaderInnerProj::Unknown { .. } => {
110                unreachable!("Can't call poll_read on Unknown")
111            }
112            DecompressedReaderInnerProj::Gzip(inner) => inner.poll_read(cx, buf),
113            DecompressedReaderInnerProj::Bzip2(inner) => inner.poll_read(cx, buf),
114            DecompressedReaderInnerProj::Xz(inner) => inner.poll_read(cx, buf),
115        }
116    }
117}
118
119#[pin_project]
120pub struct DecompressedReader<R> {
121    #[pin]
122    inner: DecompressedReaderInner<R>,
123    switch_to: Option<Algorithm>,
124}
125
126impl<R> DecompressedReader<R> {
127    pub fn new(inner: R) -> Self {
128        Self {
129            inner: DecompressedReaderInner::Unknown {
130                buffer: vec![0; BYTES_NEEDED],
131                inner: Some(inner),
132            },
133            switch_to: None,
134        }
135    }
136}
137
138impl<R> AsyncRead for DecompressedReader<R>
139where
140    R: AsyncBufRead + Unpin,
141{
142    fn poll_read(
143        self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &mut ReadBuf<'_>,
146    ) -> Poll<io::Result<()>> {
147        let mut this = self.project();
148        let (buffer, inner) = match this.inner.as_mut().project() {
149            DecompressedReaderInnerProj::Gzip(inner) => return inner.poll_read(cx, buf),
150            DecompressedReaderInnerProj::Bzip2(inner) => return inner.poll_read(cx, buf),
151            DecompressedReaderInnerProj::Xz(inner) => return inner.poll_read(cx, buf),
152            DecompressedReaderInnerProj::Unknown { buffer, inner } => (buffer, inner),
153        };
154
155        let mut our_buf = ReadBuf::new(buffer);
156        ready!(inner.as_pin_mut().unwrap().poll_read(cx, &mut our_buf))?;
157
158        let data = our_buf.filled();
159        if data.len() >= BYTES_NEEDED {
160            if let Some(algorithm) = Algorithm::from_magic(data) {
161                this.inner.as_mut().switch_to(algorithm);
162            } else {
163                return Poll::Ready(Err(io::Error::new(
164                    io::ErrorKind::InvalidData,
165                    "tar data not gz, bzip2, or xz compressed",
166                )));
167            }
168            this.inner.poll_read(cx, buf)
169        } else {
170            cx.waker().wake_by_ref();
171            Poll::Pending
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use std::path::Path;

    use async_compression::tokio::bufread::GzipEncoder;
    use futures::TryStreamExt;
    use rstest::rstest;
    use tokio::io::{AsyncReadExt, BufReader};
    use tokio_tar::Archive;

    use super::*;

    /// Round-trips a short payload through a gzip encoder and `DecompressedReader`.
    #[tokio::test]
    async fn gzip() {
        let plain = b"abcdefghijk";

        let mut compressed = Vec::new();
        GzipEncoder::new(&plain[..])
            .read_to_end(&mut compressed)
            .await
            .unwrap();

        let mut decoded = Vec::new();
        DecompressedReader::new(BufReader::new(compressed.as_slice()))
            .read_to_end(&mut decoded)
            .await
            .unwrap();

        assert_eq!(plain[..], decoded[..]);
    }

    /// Each fixture is expected to be a tarball with exactly one empty entry
    /// named `empty`, compressed with a different algorithm.
    #[rstest]
    #[case::gzip(include_bytes!("../tests/blob.tar.gz"))]
    #[case::bzip2(include_bytes!("../tests/blob.tar.bz2"))]
    #[case::xz(include_bytes!("../tests/blob.tar.xz"))]
    #[tokio::test]
    async fn compressed_tar(#[case] data: &[u8]) {
        let mut archive = Archive::new(DecompressedReader::new(BufReader::new(data)));
        let mut entries: Vec<_> = archive.entries().unwrap().try_collect().await.unwrap();
        assert_eq!(entries.len(), 1);

        let entry = &mut entries[0];
        assert_eq!(entry.path().unwrap().as_ref(), Path::new("empty"));

        let mut contents = String::new();
        entry.read_to_string(&mut contents).await.unwrap();
        assert_eq!(contents, "");
    }
}