1use std::{io, mem};
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use axum::response::{Response, IntoResponse};
6use bytes::{Bytes, BytesMut};
7use http_body::{Body, SizeHint, Frame};
8use futures::Stream;
9use pin_project::pin_project;
10use tokio::io::ReadBuf;
11
12use crate::RangeBody;
13
/// Requested capacity, in bytes, of each buffer that a chunk of the range is
/// read into (64 KiB).
const IO_BUFFER_SIZE: usize = 1 << 16;
15
16#[pin_project]
18pub struct RangedStream<B> {
19 state: StreamState,
20 length: u64,
21 #[pin]
22 body: B,
23}
24
25impl<B: RangeBody + Send + 'static> RangedStream<B> {
26 pub(crate) fn new(body: B, start: u64, length: u64) -> Self {
27 RangedStream {
28 state: StreamState::Seek { start },
29 length,
30 body,
31 }
32 }
33}
34
35#[derive(Debug)]
36enum StreamState {
37 Seek { start: u64 },
38 Seeking { remaining: u64 },
39 Reading { buffer: BytesMut, remaining: u64 },
40}
41
42impl<B: RangeBody + Send + 'static> IntoResponse for RangedStream<B> {
43 fn into_response(self) -> Response {
44 Response::new(axum::body::Body::new(self))
45 }
46}
47
48impl<B: RangeBody> Body for RangedStream<B> {
49 type Data = Bytes;
50 type Error = io::Error;
51
52 fn size_hint(&self) -> SizeHint {
53 SizeHint::with_exact(self.length)
54 }
55
56 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>)
57 -> Poll<Option<io::Result<Frame<Bytes>>>>
58 {
59 self.poll_next(cx).map(|item| item.map(|result| result.map(Frame::data)))
60 }
61}
62
63impl<B: RangeBody> Stream for RangedStream<B> {
64 type Item = io::Result<Bytes>;
65
66 fn poll_next(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>
69 ) -> Poll<Option<io::Result<Bytes>>> {
70 let mut this = self.project();
71
72 if let StreamState::Seek { start } = *this.state {
73 match this.body.as_mut().start_seek(start) {
74 Err(e) => { return Poll::Ready(Some(Err(e))); }
75 Ok(()) => {
76 let remaining = *this.length;
77 *this.state = StreamState::Seeking { remaining };
78 }
79 }
80 }
81
82 if let StreamState::Seeking { remaining } = *this.state {
83 match this.body.as_mut().poll_complete(cx) {
84 Poll::Pending => { return Poll::Pending; }
85 Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
86 Poll::Ready(Ok(())) => {
87 let buffer = allocate_buffer();
88 *this.state = StreamState::Reading { buffer, remaining };
89 }
90 }
91 }
92
93 if let StreamState::Reading { buffer, remaining } = this.state {
94 let uninit = buffer.spare_capacity_mut();
95
96 let nbytes = std::cmp::min(
99 uninit.len(),
100 usize::try_from(*remaining).unwrap_or(usize::MAX),
101 );
102
103 let mut read_buf = ReadBuf::uninit(&mut uninit[0..nbytes]);
104
105 match this.body.as_mut().poll_read(cx, &mut read_buf) {
106 Poll::Pending => { return Poll::Pending; }
107 Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
108 Poll::Ready(Ok(())) => {
109 match read_buf.filled().len() {
110 0 => { return Poll::Ready(None); }
111 n => {
112 unsafe { buffer.set_len(buffer.len() + n); }
116
117 let chunk = mem::replace(buffer, allocate_buffer());
119
120 *remaining -= u64::try_from(n).unwrap();
125
126 return Poll::Ready(Some(Ok(chunk.freeze())));
128 }
129 }
130 }
131 }
132 }
133
134 unreachable!();
135 }
136}
137
138fn allocate_buffer() -> BytesMut {
139 BytesMut::with_capacity(IO_BUFFER_SIZE)
140}