snix_castore/
hashing_reader.rs

1use pin_project_lite::pin_project;
2use tokio::io::AsyncRead;
3
4pin_project! {
5    /// Wraps an existing AsyncRead, and allows querying for the digest of all
6    /// data read "through" it.
7    /// The hash function is configurable by type parameter.
8    pub struct HashingReader<R, H>
9    where
10        R: AsyncRead,
11        H: digest::Digest,
12    {
13        #[pin]
14        inner: R,
15        hasher: H,
16    }
17}
18
19pub type B3HashingReader<R> = HashingReader<R, blake3::Hasher>;
20
21impl<R, H> HashingReader<R, H>
22where
23    R: AsyncRead,
24    H: digest::Digest,
25{
26    pub fn from(r: R) -> Self {
27        Self {
28            inner: r,
29            hasher: H::new(),
30        }
31    }
32
33    /// Return the digest.
34    pub fn digest(self) -> digest::Output<H> {
35        self.hasher.finalize()
36    }
37}
38
39impl<R, H> tokio::io::AsyncRead for HashingReader<R, H>
40where
41    R: AsyncRead,
42    H: digest::Digest,
43{
44    fn poll_read(
45        self: std::pin::Pin<&mut Self>,
46        cx: &mut std::task::Context<'_>,
47        buf: &mut tokio::io::ReadBuf<'_>,
48    ) -> std::task::Poll<std::io::Result<()>> {
49        let buf_filled_len_before = buf.filled().len();
50
51        let this = self.project();
52        let ret = this.inner.poll_read(cx, buf);
53
54        // write everything new filled into the hasher.
55        this.hasher.update(&buf.filled()[buf_filled_len_before..]);
56
57        ret
58    }
59}
60
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use rstest::rstest;

    use crate::fixtures::{BLOB_A, BLOB_A_DIGEST, BLOB_B, BLOB_B_DIGEST, EMPTY_BLOB_DIGEST};
    use crate::{B3Digest, B3HashingReader};

    /// Drains each fixture through a `B3HashingReader` and checks that the
    /// resulting digest equals the expected one.
    #[rstest]
    #[case::blob_a(&BLOB_A, &BLOB_A_DIGEST)]
    #[case::blob_b(&BLOB_B, &BLOB_B_DIGEST)]
    #[case::empty_blob(&[], &EMPTY_BLOB_DIGEST)]
    #[tokio::test]
    async fn test_b3_hashing_reader(#[case] data: &[u8], #[case] b3_digest: &B3Digest) {
        let mut reader = B3HashingReader::from(Cursor::new(data));
        let mut discard = tokio::io::sink();

        // Read the whole input; the bytes themselves are thrown away.
        tokio::io::copy(&mut reader, &mut discard)
            .await
            .expect("read must succeed");

        assert_eq!(*b3_digest, reader.digest().into());
    }
}