// axum_range/file.rs

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, ReadBuf};

use crate::{AsyncSeekStart, RangeBody};

10/// Implements [`RangeBody`] for any [`AsyncRead`] and [`AsyncSeekStart`], constructed with a fixed byte size.
11#[pin_project]
12pub struct KnownSize<B: AsyncRead + AsyncSeekStart> {
13    byte_size: u64,
14    #[pin]
15    body: B,
16}
17
18impl KnownSize<tokio::fs::File> {
19    /// Calls [`tokio::fs::File::metadata`] to determine file size.
20    pub async fn file(file: tokio::fs::File) -> io::Result<KnownSize<tokio::fs::File>> {
21        let byte_size = file.metadata().await?.len();
22        Ok(KnownSize { byte_size, body: file })
23    }
24}
25
26impl<B: AsyncRead + AsyncSeekStart> KnownSize<B> {
27    /// Construct a [`KnownSize`] instance with a byte size supplied manually.
28    pub fn sized(body: B, byte_size: u64) -> Self {
29        KnownSize { byte_size, body }
30    }
31}
32
33impl<B: AsyncRead + AsyncSeek + Unpin> KnownSize<B> {
34    /// Uses `seek` to determine size by seeking to the end and getting stream position.
35    pub async fn seek(mut body: B) -> io::Result<KnownSize<B>> {
36        let byte_size = Pin::new(&mut body).seek(io::SeekFrom::End(0)).await?;
37        Ok(KnownSize { byte_size, body })
38    }
39}
40
41impl<B: AsyncRead + AsyncSeekStart> AsyncRead for KnownSize<B> {
42    fn poll_read(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45        buf: &mut ReadBuf<'_>,
46    ) -> Poll<io::Result<()>> {
47        let this = self.project();
48        this.body.poll_read(cx, buf)
49    }
50}
51
52impl<B: AsyncRead + AsyncSeekStart> AsyncSeekStart for KnownSize<B> {
53    fn start_seek(
54        self: Pin<&mut Self>,
55        position: u64,
56    ) -> io::Result<()> {
57        let this = self.project();
58        this.body.start_seek(position)
59    }
60
61    fn poll_complete(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64    ) -> Poll<io::Result<()>> {
65        let this = self.project();
66        this.body.poll_complete(cx)
67    }
68}
69
70impl<B: AsyncRead + AsyncSeekStart> RangeBody for KnownSize<B> {
71    fn byte_size(&self) -> u64 {
72        self.byte_size
73    }
74}
75
#[cfg(test)]
mod tests {
    use tokio::fs::File;

    use crate::RangeBody;

    use super::KnownSize;

    /// The metadata-based constructor reports the fixture's length (54 bytes).
    #[tokio::test]
    async fn test_file_size() {
        let file = File::open("test/fixture.txt").await.unwrap();
        let known_size = KnownSize::file(file).await.unwrap();
        assert_eq!(54, known_size.byte_size());
    }

    /// The seek-based constructor reports the same length as the metadata-based one.
    #[tokio::test]
    async fn test_seek_size() {
        let file = File::open("test/fixture.txt").await.unwrap();
        // Must exercise `KnownSize::seek`, not `KnownSize::file`, or this test
        // duplicates `test_file_size` and leaves the seek path untested.
        let known_size = KnownSize::seek(file).await.unwrap();
        assert_eq!(54, known_size.byte_size());
    }
}