snix_store/
decompression.rs1use std::{
2 io::{self, Cursor},
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use async_compression::tokio::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
8use pin_project::pin_project;
9use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf};
10
/// Leading bytes of a gzip stream.
const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Leading bytes of a bzip2 stream (`"BZh"`).
const BZIP2_MAGIC: [u8; 3] = *b"BZh";
/// Leading bytes of an xz stream.
const XZ_MAGIC: [u8; 6] = [0xfd, 0x37, 0x7a, 0x58, 0x5a, 0x00];
/// Leading bytes of a zstd frame.
const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];
/// Number of leading bytes to inspect; equals the length of the longest
/// signature above (`XZ_MAGIC`).
const BYTES_NEEDED: usize = 6;

/// Compression formats that can be recognised from their magic bytes.
#[derive(Debug, Clone, Copy)]
enum Algorithm {
    Gzip,
    Bzip2,
    Xz,
    Zstd,
}

impl Algorithm {
    /// Identifies the compression format whose signature `magic` begins with.
    ///
    /// Returns `None` when `magic` is too short to contain, or simply does not
    /// start with, any of the known signatures.
    fn from_magic(magic: &[u8]) -> Option<Self> {
        // Probed in the listed order: gzip, bzip2, xz, zstd.
        let signatures = [
            (&GZIP_MAGIC[..], Self::Gzip),
            (&BZIP2_MAGIC[..], Self::Bzip2),
            (&XZ_MAGIC[..], Self::Xz),
            (&ZSTD_MAGIC[..], Self::Zstd),
        ];

        signatures
            .iter()
            .find(|(signature, _)| magic.starts_with(signature))
            .map(|&(_, algorithm)| algorithm)
    }
}
40
/// Inline byte buffer with capacity `N`, of which only a leading prefix holds
/// data.
///
/// Both `Deref` and `AsRef<[u8]>` expose exactly that prefix.
#[derive(Clone)]
pub struct SmallBuf<const N: usize> {
    /// Length of the populated prefix of `buf`.
    data_len: u8,
    /// Backing storage; bytes at index `data_len` and beyond are never exposed.
    buf: [u8; N],
}

impl<const N: usize> std::ops::Deref for SmallBuf<N> {
    type Target = [u8];

    /// Borrows the populated prefix of the buffer.
    fn deref(&self) -> &Self::Target {
        let filled = usize::from(self.data_len);
        &self.buf[..filled]
    }
}

impl<const N: usize> AsRef<[u8]> for SmallBuf<N> {
    /// Same view as `Deref`: the populated prefix only.
    fn as_ref(&self) -> &[u8] {
        &**self
    }
}
60
61type WithPreexistingBuffer<R> = tokio::io::Chain<io::Cursor<SmallBuf<BYTES_NEEDED>>, R>;
62
63#[pin_project(project = DecompressedReaderProj)]
64pub enum DecompressedReader<R> {
65 Unknown(#[pin] WithPreexistingBuffer<R>),
66 Gzip(#[pin] GzipDecoder<WithPreexistingBuffer<R>>),
67 Bzip(#[pin] BzDecoder<WithPreexistingBuffer<R>>),
68 Xz(#[pin] XzDecoder<WithPreexistingBuffer<R>>),
69 Zstd(#[pin] ZstdDecoder<WithPreexistingBuffer<R>>),
70}
71
72impl<R: AsyncRead + Unpin> DecompressedReader<R> {
73 async fn buffer_magic(
76 mut r: R,
77 ) -> std::io::Result<(SmallBuf<BYTES_NEEDED>, WithPreexistingBuffer<R>)> {
78 let mut buf = [0; BYTES_NEEDED];
79 let mut buf_filled = 0;
80
81 while buf_filled < BYTES_NEEDED {
82 let bytes_read = r.read(&mut buf).await?;
83 if bytes_read == 0 {
85 tracing::trace!("got EOF while filling buffer");
86 break;
87 }
88 buf_filled += bytes_read;
89 }
90
91 let buf = SmallBuf {
92 data_len: buf_filled as u8,
93 buf,
94 };
95
96 Ok((buf.clone(), Cursor::new(buf).chain(r)))
97 }
98}
99
100impl<R: AsyncBufRead + Unpin> DecompressedReader<R> {
101 pub async fn new(reader: R) -> std::io::Result<Self> {
106 let (buffer, r) = Self::buffer_magic(reader).await?;
107
108 Ok(match Algorithm::from_magic(&buffer) {
110 Some(Algorithm::Gzip) => Self::Gzip(GzipDecoder::new(r)),
111 Some(Algorithm::Bzip2) => Self::Bzip(BzDecoder::new(r)),
112 Some(Algorithm::Xz) => Self::Xz(XzDecoder::new(r)),
113 Some(Algorithm::Zstd) => Self::Zstd(ZstdDecoder::new(r)),
114 None => Self::Unknown(r),
115 })
116 }
117
118 #[allow(unused)]
122 pub fn as_unknown_inner(&mut self) -> Option<&mut WithPreexistingBuffer<R>> {
123 if let Self::Unknown(r) = self {
124 Some(r)
125 } else {
126 None
127 }
128 }
129}
130
131impl<R> AsyncRead for DecompressedReader<R>
132where
133 R: AsyncBufRead,
134{
135 fn poll_read(
136 self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 buf: &mut ReadBuf<'_>,
139 ) -> Poll<io::Result<()>> {
140 match self.project() {
141 DecompressedReaderProj::Unknown(inner) => inner.poll_read(cx, buf),
142 DecompressedReaderProj::Gzip(inner) => inner.poll_read(cx, buf),
143 DecompressedReaderProj::Bzip(inner) => inner.poll_read(cx, buf),
144 DecompressedReaderProj::Xz(inner) => inner.poll_read(cx, buf),
145 DecompressedReaderProj::Zstd(inner) => inner.poll_read(cx, buf),
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
    use std::path::Path;

    use async_compression::tokio::bufread::GzipEncoder;
    use futures::TryStreamExt;
    use rstest::rstest;
    use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
    use tokio_tar::Archive;

    use super::*;

    /// Gzip-compresses a short payload and checks it round-trips through the reader.
    #[tokio::test]
    async fn gzip() {
        let plaintext = b"abcdefghijk";

        let mut compressed = Vec::new();
        GzipEncoder::new(&plaintext[..])
            .read_to_end(&mut compressed)
            .await
            .unwrap();

        let mut decompressor = DecompressedReader::new(BufReader::new(&compressed[..]))
            .await
            .expect("new to succeed");
        let mut decompressed = Vec::new();
        decompressor.read_to_end(&mut decompressed).await.unwrap();

        assert_eq!(plaintext[..], decompressed[..]);
    }

    /// Unpacks the same tarball in each supported compression format and checks
    /// that it contains exactly one empty entry named `empty`.
    #[rstest]
    #[case::gzip(include_bytes!("tests/blob.tar.gz"))]
    #[case::bzip2(include_bytes!("tests/blob.tar.bz2"))]
    #[case::xz(include_bytes!("tests/blob.tar.xz"))]
    #[case::zstd(include_bytes!("tests/blob.tar.zst"))]
    #[tokio::test]
    async fn compressed_tar(#[case] data: &[u8]) {
        let decompressed = DecompressedReader::new(BufReader::new(data))
            .await
            .expect("new to succeed");
        let mut tarball = Archive::new(decompressed);
        let mut entries: Vec<_> = tarball.entries().unwrap().try_collect().await.unwrap();

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].path().unwrap().as_ref(), Path::new("empty"));

        let mut contents = String::new();
        entries[0].read_to_string(&mut contents).await.unwrap();
        assert_eq!(contents, "");
    }

    /// Data without a known signature is passed through unchanged, including
    /// the bytes that were consumed while sniffing for a signature.
    #[tokio::test]
    async fn unknown() {
        let data = b"abcdefghijk";
        let mut reader = DecompressedReader::new(BufReader::new(Cursor::new(data)))
            .await
            .expect("new to succeed");

        // Touch the inner stream's buffer first; the result is deliberately
        // ignored, and the full read below must still yield every byte.
        let _ = reader
            .as_unknown_inner()
            .expect("to be some")
            .fill_buf()
            .await;

        let mut passthrough = Vec::new();
        reader
            .read_to_end(&mut passthrough)
            .await
            .expect("read_to_end to not fail");

        assert_eq!(data[..], passthrough[..], "read data should match");
    }
}