axum_range/
stream.rs

1use std::{io, mem};
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::response::{Response, IntoResponse};
6use bytes::{Bytes, BytesMut};
7use http_body::{Body, SizeHint, Frame};
8use futures::Stream;
9use pin_project::pin_project;
10use tokio::io::ReadBuf;
11
12use crate::RangeBody;
13
/// Capacity, in bytes, of every buffer handed out by `allocate_buffer`.
///
/// Because each read fills at most one such buffer, this is also the upper
/// bound on the size of a single chunk yielded by the stream.
const IO_BUFFER_SIZE: usize = 64 * 1024;
15
/// Response body stream. Implements [`Stream`], [`Body`], and [`IntoResponse`].
///
/// The stream first seeks the wrapped body to a start offset and then yields
/// up to `length` bytes from it in chunks.
#[pin_project]
pub struct RangedStream<B> {
    /// Current phase of the seek-then-read state machine driven by `poll_next`.
    state: StreamState,
    /// Total number of bytes this stream was asked to yield.
    length: u64,
    /// Underlying body; pinned because its seek and read methods are invoked
    /// through `Pin<&mut B>`.
    #[pin]
    body: B,
}
24
25impl<B: RangeBody + Send + 'static> RangedStream<B> {
26    pub(crate) fn new(body: B, start: u64, length: u64) -> Self {
27        RangedStream {
28            state: StreamState::Seek { start },
29            length,
30            body,
31        }
32    }
33}
34
/// Phases of the seek-then-read state machine driven by `poll_next`.
#[derive(Debug)]
enum StreamState {
    /// Initial state: the seek to offset `start` has not been started yet.
    Seek { start: u64 },
    /// The seek has been started and its completion is being awaited;
    /// `remaining` is the number of bytes to yield once it finishes.
    Seeking { remaining: u64 },
    /// The seek has completed; `buffer` receives the next read and
    /// `remaining` counts the bytes still to be yielded.
    Reading { buffer: BytesMut, remaining: u64 },
}
41
42impl<B: RangeBody + Send + 'static> IntoResponse for RangedStream<B> {
43    fn into_response(self) -> Response {
44        Response::new(axum::body::Body::new(self))
45    }
46}
47
48impl<B: RangeBody> Body for RangedStream<B> {
49    type Data = Bytes;
50    type Error = io::Error;
51
52    fn size_hint(&self) -> SizeHint {
53        SizeHint::with_exact(self.length)
54    }
55
56    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>)
57        -> Poll<Option<io::Result<Frame<Bytes>>>>
58    {
59        self.poll_next(cx).map(|item| item.map(|result| result.map(Frame::data)))
60    }
61}
62
63impl<B: RangeBody> Stream for RangedStream<B> {
64    type Item = io::Result<Bytes>;
65
66    fn poll_next(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>
69    ) -> Poll<Option<io::Result<Bytes>>> {
70        let mut this = self.project();
71
72        if let StreamState::Seek { start } = *this.state {
73            match this.body.as_mut().start_seek(start) {
74                Err(e) => { return Poll::Ready(Some(Err(e))); }
75                Ok(()) => {
76                    let remaining = *this.length;
77                    *this.state = StreamState::Seeking { remaining };
78                }
79            }
80        }
81
82        if let StreamState::Seeking { remaining } = *this.state {
83            match this.body.as_mut().poll_complete(cx) {
84                Poll::Pending => { return Poll::Pending; }
85                Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
86                Poll::Ready(Ok(())) => {
87                    let buffer = allocate_buffer();
88                    *this.state = StreamState::Reading { buffer, remaining };
89                }
90            }
91        }
92
93        if let StreamState::Reading { buffer, remaining } = this.state {
94            let uninit = buffer.spare_capacity_mut();
95
96            // calculate max number of bytes to read in this iteration, the
97            // smaller of the buffer size and the number of bytes remaining
98            let nbytes = std::cmp::min(
99                uninit.len(),
100                usize::try_from(*remaining).unwrap_or(usize::MAX),
101            );
102
103            let mut read_buf = ReadBuf::uninit(&mut uninit[0..nbytes]);
104
105            match this.body.as_mut().poll_read(cx, &mut read_buf) {
106                Poll::Pending => { return Poll::Pending; }
107                Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
108                Poll::Ready(Ok(())) => {
109                    match read_buf.filled().len() {
110                        0 => { return Poll::Ready(None); }
111                        n => {
112                            // SAFETY: poll_read has filled the buffer with `n`
113                            // additional bytes. `buffer.len` should always be
114                            // 0 here, but include it for rigorous correctness
115                            unsafe { buffer.set_len(buffer.len() + n); }
116
117                            // replace state buffer and take this one to return
118                            let chunk = mem::replace(buffer, allocate_buffer());
119
120                            // subtract the number of bytes we just read from
121                            // state.remaining, this usize->u64 conversion is
122                            // guaranteed to always succeed, because n cannot be
123                            // larger than remaining due to the cmp::min above
124                            *remaining -= u64::try_from(n).unwrap();
125
126                            // return this chunk
127                            return Poll::Ready(Some(Ok(chunk.freeze())));
128                        }
129                    }
130                }
131            }
132        }
133
134        unreachable!();
135    }
136}
137
138fn allocate_buffer() -> BytesMut {
139    BytesMut::with_capacity(IO_BUFFER_SIZE)
140}