Skip to main content

snix_store/
decompression.rs

1use std::{
2    io::{self, Cursor},
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
8use pin_project::pin_project;
9use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf};
10
/// Leading bytes that identify a gzip stream.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Leading bytes that identify a bzip2 stream.
const BZIP2_MAGIC: [u8; 3] = *b"BZh";
/// Leading bytes that identify an xz stream.
const XZ_MAGIC: [u8; 6] = [0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00];
/// Leading bytes that identify a zstd stream.
const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];

/// Number of leading bytes that must be buffered so every magic above can be
/// tested: the length of the longest one (currently 6, for xz).
///
/// Derived at compile time instead of hard-coded, so adding a longer magic
/// cannot silently leave it undetectable.
const BYTES_NEEDED: usize = {
    let lens = [
        GZIP_MAGIC.len(),
        BZIP2_MAGIC.len(),
        XZ_MAGIC.len(),
        ZSTD_MAGIC.len(),
    ];
    let mut longest = 0;
    let mut i = 0;
    // `Iterator::max` is not usable in const context, hence the manual loop.
    while i < lens.len() {
        if lens[i] > longest {
            longest = lens[i];
        }
        i += 1;
    }
    longest
};
16
/// Compression formats that can be recognised from a stream's leading bytes.
#[derive(Debug, Clone, Copy)]
enum Algorithm {
    /// Stream starts with `GZIP_MAGIC`.
    Gzip,
    /// Stream starts with `BZIP2_MAGIC`.
    Bzip2,
    /// Stream starts with `XZ_MAGIC`.
    Xz,
    /// Stream starts with `ZSTD_MAGIC`.
    Zstd,
}
24
25impl Algorithm {
26    fn from_magic(magic: &[u8]) -> Option<Self> {
27        if magic.starts_with(&GZIP_MAGIC) {
28            Some(Self::Gzip)
29        } else if magic.starts_with(&BZIP2_MAGIC) {
30            Some(Self::Bzip2)
31        } else if magic.starts_with(&XZ_MAGIC) {
32            Some(Self::Xz)
33        } else if magic.starts_with(&ZSTD_MAGIC) {
34            Some(Self::Zstd)
35        } else {
36            None
37        }
38    }
39}
40
/// Fixed-capacity inline byte buffer of which only the first `data_len`
/// bytes carry data.
///
/// Both `AsRef<[u8]>` and `Deref<Target = [u8]>` expose exactly that filled
/// prefix; the rest of `buf` is ignored.
#[derive(Clone)]
pub struct SmallBuf<const N: usize> {
    /// Count of leading bytes of `buf` that hold data. Must not exceed `N`,
    /// otherwise slicing in `deref`/`as_ref` panics.
    data_len: u8,
    /// Backing storage; bytes at `data_len..` are not exposed.
    buf: [u8; N],
}

impl<const N: usize> std::ops::Deref for SmallBuf<N> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        let filled = usize::from(self.data_len);
        &self.buf[..filled]
    }
}

impl<const N: usize> AsRef<[u8]> for SmallBuf<N> {
    fn as_ref(&self) -> &[u8] {
        // Deref coercion: reuse the `Deref` impl so both views always agree.
        self
    }
}
60
61type WithPreexistingBuffer<R> = tokio::io::Chain<io::Cursor<SmallBuf<BYTES_NEEDED>>, R>;
62
63#[pin_project(project = DecompressedReaderProj)]
64pub enum DecompressedReader<R> {
65    Unknown(#[pin] WithPreexistingBuffer<R>),
66    Gzip(#[pin] GzipDecoder<WithPreexistingBuffer<R>>),
67    Bzip(#[pin] BzDecoder<WithPreexistingBuffer<R>>),
68    Xz(#[pin] XzDecoder<WithPreexistingBuffer<R>>),
69    Zstd(#[pin] ZstdDecoder<WithPreexistingBuffer<R>>),
70}
71
72impl<R: AsyncRead + Unpin> DecompressedReader<R> {
73    /// Reads up to `BYTES_NEEDED` bytes from the underlying reader,
74    /// and returns those Bytes as a SmallBuf, as well as a reader allowing to read all data
75    async fn buffer_magic(
76        mut r: R,
77    ) -> std::io::Result<(SmallBuf<BYTES_NEEDED>, WithPreexistingBuffer<R>)> {
78        let mut buf = [0; BYTES_NEEDED];
79        let mut buf_filled = 0;
80
81        while buf_filled < BYTES_NEEDED {
82            let bytes_read = r.read(&mut buf).await?;
83            // EOF while filling buf
84            if bytes_read == 0 {
85                tracing::trace!("got EOF while filling buffer");
86                break;
87            }
88            buf_filled += bytes_read;
89        }
90
91        let buf = SmallBuf {
92            data_len: buf_filled as u8,
93            buf,
94        };
95
96        Ok((buf.clone(), Cursor::new(buf).chain(r)))
97    }
98}
99
100impl<R: AsyncBufRead + Unpin> DecompressedReader<R> {
101    /// Checks the passed reader for a suitable compression magic to be present,
102    /// pulling up to 6 bytes of data.
103    /// If there is, returns a reader which allows reading decompressed data.
104    /// Else, returns a reader that reads the uncompressed/undetected data (including the up to 6 bytes).
105    pub async fn new(reader: R) -> std::io::Result<Self> {
106        let (buffer, r) = Self::buffer_magic(reader).await?;
107
108        // r.buffer is guaranteed to have at least BYTES_NEEDED bytes if not EOF before.
109        Ok(match Algorithm::from_magic(&buffer) {
110            Some(Algorithm::Gzip) => Self::Gzip(GzipDecoder::new(r)),
111            Some(Algorithm::Bzip2) => Self::Bzip(BzDecoder::new(r)),
112            Some(Algorithm::Xz) => Self::Xz(XzDecoder::new(r)),
113            Some(Algorithm::Zstd) => Self::Zstd(ZstdDecoder::new(r)),
114            None => Self::Unknown(r),
115        })
116    }
117
118    /// The decoders don't implement [AsyncBufRead], so we cannot implement it for [DecompressedReader].
119    /// However, in the unknown case we only wrap R (and the buffer we read so far), they are [AsyncBufRead].
120    /// This provides a way to get a mutable reference to it.
121    #[allow(unused)]
122    pub fn as_unknown_inner(&mut self) -> Option<&mut WithPreexistingBuffer<R>> {
123        if let Self::Unknown(r) = self {
124            Some(r)
125        } else {
126            None
127        }
128    }
129}
130
131impl<R> AsyncRead for DecompressedReader<R>
132where
133    R: AsyncBufRead,
134{
135    fn poll_read(
136        self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        buf: &mut ReadBuf<'_>,
139    ) -> Poll<io::Result<()>> {
140        match self.project() {
141            DecompressedReaderProj::Unknown(inner) => inner.poll_read(cx, buf),
142            DecompressedReaderProj::Gzip(inner) => inner.poll_read(cx, buf),
143            DecompressedReaderProj::Bzip(inner) => inner.poll_read(cx, buf),
144            DecompressedReaderProj::Xz(inner) => inner.poll_read(cx, buf),
145            DecompressedReaderProj::Zstd(inner) => inner.poll_read(cx, buf),
146        }
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::path::Path;

    use async_compression::tokio::bufread::GzipEncoder;
    use futures::TryStreamExt;
    use rstest::rstest;
    use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
    use tokio_tar::Archive;

    use super::*;

    /// Gzip-compresses a short payload and checks that reading it back through
    /// `DecompressedReader` yields the original bytes.
    #[tokio::test]
    async fn gzip() {
        let plain = b"abcdefghijk";

        let mut compressed = Vec::new();
        GzipEncoder::new(&plain[..])
            .read_to_end(&mut compressed)
            .await
            .unwrap();

        let mut decoded = Vec::new();
        DecompressedReader::new(BufReader::new(compressed.as_slice()))
            .await
            .expect("new to succeed")
            .read_to_end(&mut decoded)
            .await
            .unwrap();

        assert_eq!(plain[..], decoded[..]);
    }

    /// Each fixture is a tarball holding a single empty file named `empty`,
    /// compressed with a different algorithm.
    #[rstest]
    #[case::gzip(include_bytes!("tests/blob.tar.gz"))]
    #[case::bzip2(include_bytes!("tests/blob.tar.bz2"))]
    #[case::xz(include_bytes!("tests/blob.tar.xz"))]
    #[case::zstd(include_bytes!("tests/blob.tar.zst"))]
    #[tokio::test]
    async fn compressed_tar(#[case] data: &[u8]) {
        let decompressed = DecompressedReader::new(BufReader::new(data))
            .await
            .expect("new to succeed");
        let mut archive = Archive::new(decompressed);
        let mut entries: Vec<_> = archive.entries().unwrap().try_collect().await.unwrap();

        assert_eq!(entries.len(), 1);
        let entry = &mut entries[0];
        assert_eq!(entry.path().unwrap().as_ref(), Path::new("empty"));

        let mut contents = String::new();
        entry.read_to_string(&mut contents).await.unwrap();
        assert_eq!(contents, "");
    }

    /// Input with no known magic must be passed through byte-for-byte.
    #[tokio::test]
    async fn unknown() {
        let data = b"abcdefghijk";
        let mut reader = DecompressedReader::new(BufReader::new(Cursor::new(data)))
            .await
            .expect("new to succeed");

        // The undetected case exposes an AsyncBufRead-capable inner reader...
        let _ = reader
            .as_unknown_inner()
            .expect("to be some")
            .fill_buf()
            .await;

        // ...and the outer reader still yields everything, sniffed prefix included.
        let mut out = Vec::new();
        reader
            .read_to_end(&mut out)
            .await
            .expect("read_to_end to not fail");

        assert_eq!(data[..], out[..], "read data should match");
    }
}