async_compression/codec/zstd/decoder.rs

use std::io;
use std::io::Result;

use libzstd::stream::raw::{Decoder, Operation};

use crate::{codec::Decode, unshared::Unshared, util::PartialBuffer};

7#[derive(Debug)]
8pub struct ZstdDecoder {
9    decoder: Unshared<Decoder<'static>>,
10}
11
12impl ZstdDecoder {
13    pub(crate) fn new() -> Self {
14        Self {
15            decoder: Unshared::new(Decoder::new().unwrap()),
16        }
17    }
18
19    pub(crate) fn new_with_params(params: &[crate::zstd::DParameter]) -> Self {
20        let mut decoder = Decoder::new().unwrap();
21        for param in params {
22            decoder.set_parameter(param.as_zstd()).unwrap();
23        }
24        Self {
25            decoder: Unshared::new(decoder),
26        }
27    }
28
29    pub(crate) fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
30        let mut decoder = Decoder::with_dictionary(dictionary)?;
31        Ok(Self {
32            decoder: Unshared::new(decoder),
33        })
34    }
35}
36
37impl Decode for ZstdDecoder {
38    fn reinit(&mut self) -> Result<()> {
39        self.decoder.get_mut().reinit()?;
40        Ok(())
41    }
42
43    fn decode(
44        &mut self,
45        input: &mut PartialBuffer<impl AsRef<[u8]>>,
46        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
47    ) -> Result<bool> {
48        let status = self
49            .decoder
50            .get_mut()
51            .run_on_buffers(input.unwritten(), output.unwritten_mut())?;
52        input.advance(status.bytes_read);
53        output.advance(status.bytes_written);
54        Ok(status.remaining == 0)
55    }
56
57    fn flush(
58        &mut self,
59        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
60    ) -> Result<bool> {
61        let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut());
62        let bytes_left = self.decoder.get_mut().flush(&mut out_buf)?;
63        let len = out_buf.as_slice().len();
64        output.advance(len);
65        Ok(bytes_left == 0)
66    }
67
68    fn finish(
69        &mut self,
70        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
71    ) -> Result<bool> {
72        let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut());
73        let bytes_left = self.decoder.get_mut().finish(&mut out_buf, true)?;
74        let len = out_buf.as_slice().len();
75        output.advance(len);
76        Ok(bytes_left == 0)
77    }
78}