1mod file;
50mod stream;
51
52use std::io;
53use std::ops::Bound;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56
57use axum::http::StatusCode;
58use axum::response::{IntoResponse, Response};
59use axum_extra::TypedHeader;
60use axum_extra::headers::{Range, ContentRange, ContentLength, AcceptRanges};
61use tokio::io::{AsyncRead, AsyncSeek};
62
63pub use file::KnownSize;
64pub use stream::RangedStream;
65
/// Absolute-offset seeking: the subset of tokio's `AsyncSeek` that a
/// [`RangeBody`] must support.
///
/// A blanket implementation exists for every `AsyncSeek` type, so most bodies
/// never implement this by hand.
pub trait AsyncSeekStart {
    /// Begin moving the read position to `position` bytes from the start of
    /// the body.
    ///
    /// The seek is driven to completion with
    /// [`poll_complete`](AsyncSeekStart::poll_complete).
    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()>;

    /// Poll the seek begun by [`start_seek`](AsyncSeekStart::start_seek)
    /// until it has finished.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
74
75impl<T: AsyncSeek> AsyncSeekStart for T {
76 fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()> {
77 AsyncSeek::start_seek(self, io::SeekFrom::Start(position))
78 }
79
80 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 AsyncSeek::poll_complete(self, cx).map_ok(|_| ())
82 }
83}
84
85pub trait RangeBody: AsyncRead + AsyncSeekStart {
87 fn byte_size(&self) -> u64;
92}
93
94pub struct Ranged<B: RangeBody + Send + 'static> {
96 range: Option<Range>,
97 body: B,
98}
99
100impl<B: RangeBody + Send + 'static> Ranged<B> {
101 pub fn new(range: Option<Range>, body: B) -> Self {
104 Ranged { range, body }
105 }
106
107 pub fn try_respond(self) -> Result<RangedResponse<B>, RangeNotSatisfiable> {
111 let total_bytes = self.body.byte_size();
112
113 let range = self.range.and_then(|range| {
117 range.satisfiable_ranges(total_bytes).nth(0)
118 });
119
120 let seek_start = match range {
122 Some((Bound::Included(seek_start), _)) => seek_start,
123 _ => 0,
124 };
125
126 let seek_end_excl = match range {
127 Some((_, Bound::Included(end))) => end + 1,
129 _ => total_bytes,
130 };
131
132 let seek_start_beyond_seek_end = seek_start > seek_end_excl;
134 let seek_end_beyond_file_range = seek_end_excl > total_bytes;
135 let zero_length_range = seek_start == seek_end_excl;
137
138 if seek_start_beyond_seek_end || seek_end_beyond_file_range || zero_length_range {
139 let content_range = ContentRange::unsatisfied_bytes(total_bytes);
140 return Err(RangeNotSatisfiable(content_range));
141 }
142
143 let content_range = range.map(|_| {
145 ContentRange::bytes(seek_start..seek_end_excl, total_bytes)
146 .expect("ContentRange::bytes cannot panic in this usage")
147 });
148
149 let content_length = ContentLength(seek_end_excl - seek_start);
150
151 let stream = RangedStream::new(self.body, seek_start, content_length.0);
152
153 Ok(RangedResponse {
154 content_range,
155 content_length,
156 stream,
157 })
158 }
159}
160
161impl<B: RangeBody + Send + 'static> IntoResponse for Ranged<B> {
162 fn into_response(self) -> Response {
163 self.try_respond().into_response()
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct RangeNotSatisfiable(pub ContentRange);
170
171impl IntoResponse for RangeNotSatisfiable {
172 fn into_response(self) -> Response {
173 let status = StatusCode::RANGE_NOT_SATISFIABLE;
174 let header = TypedHeader(self.0);
175 (status, header, ()).into_response()
176 }
177}
178
179pub struct RangedResponse<B> {
181 pub content_range: Option<ContentRange>,
182 pub content_length: ContentLength,
183 pub stream: RangedStream<B>,
184}
185
186impl<B: RangeBody + Send + 'static> IntoResponse for RangedResponse<B> {
187 fn into_response(self) -> Response {
188 let content_range = self.content_range.map(TypedHeader);
189 let content_length = TypedHeader(self.content_length);
190 let accept_ranges = TypedHeader(AcceptRanges::bytes());
191 let stream = self.stream;
192
193 let status = match content_range {
194 Some(_) => StatusCode::PARTIAL_CONTENT,
195 None => StatusCode::OK,
196 };
197
198 (status, content_range, content_length, accept_ranges, stream).into_response()
199 }
200}
201
#[cfg(test)]
mod tests {
    use std::io;

    use axum::http::HeaderValue;
    use axum_extra::headers::{ContentRange, Header, Range};
    use bytes::Bytes;
    use futures::{pin_mut, Stream, StreamExt};
    use tokio::fs::File;

    use crate::Ranged;
    use crate::KnownSize;

    /// Drain `stream` and decode everything it yielded as one UTF-8 string.
    ///
    /// Chunks are concatenated before decoding, so a multi-byte character
    /// split across two chunks does not make the helper panic.
    async fn collect_stream(stream: impl Stream<Item = io::Result<Bytes>>) -> String {
        let mut bytes = Vec::new();
        pin_mut!(stream);
        while let Some(chunk) = stream.next().await.transpose().unwrap() {
            bytes.extend_from_slice(&chunk);
        }
        String::from_utf8(bytes).unwrap()
    }

    /// Parse a raw `Range` header value, panicking if it is malformed.
    fn range(header: &str) -> Option<Range> {
        let val = HeaderValue::from_str(header).unwrap();
        Some(Range::decode(&mut [val].iter()).unwrap())
    }

    /// Open the 54-byte test fixture as a sized body.
    async fn body() -> KnownSize<File> {
        let file = File::open("test/fixture.txt").await.unwrap();
        KnownSize::file(file).await.unwrap()
    }

    #[tokio::test]
    async fn test_full_response() {
        let ranged = Ranged::new(None, body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(54, response.content_length.0);
        assert!(response.content_range.is_none());
        assert_eq!("Hello world this is a file to test range requests on!\n",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_partial_response_1() {
        let ranged = Ranged::new(range("bytes=0-29"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(30, response.content_length.0);

        let expected_content_range = ContentRange::bytes(0..30, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("Hello world this is a file to ",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_partial_response_2() {
        let ranged = Ranged::new(range("bytes=30-53"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(24, response.content_length.0);

        let expected_content_range = ContentRange::bytes(30..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("test range requests on!\n",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_unbounded_start_response() {
        // A suffix range: the last 20 bytes of the 54-byte fixture.
        let ranged = Ranged::new(range("bytes=-20"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(20, response.content_length.0);

        let expected_content_range = ContentRange::bytes(34..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!(" range requests on!\n",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_unbounded_end_response() {
        let ranged = Ranged::new(range("bytes=40-"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(14, response.content_length.0);

        let expected_content_range = ContentRange::bytes(40..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!(" requests on!\n",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_one_byte_response() {
        let ranged = Ranged::new(range("bytes=30-30"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(1, response.content_length.0);

        let expected_content_range = ContentRange::bytes(30..31, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("t",
            &collect_stream(response.stream).await);
    }

    #[tokio::test]
    async fn test_invalid_range() {
        let ranged = Ranged::new(range("bytes=30-29"), body().await);

        let err = ranged.try_respond().err().expect("try_respond should return Err");

        let expected_content_range = ContentRange::unsatisfied_bytes(54);
        assert_eq!(expected_content_range, err.0)
    }

    #[tokio::test]
    async fn test_range_end_exceed_length() {
        let ranged = Ranged::new(range("bytes=30-99"), body().await);

        let err = ranged.try_respond().err().expect("try_respond should return Err");

        let expected_content_range = ContentRange::unsatisfied_bytes(54);
        assert_eq!(expected_content_range, err.0)
    }

    #[tokio::test]
    async fn test_range_start_exceed_length() {
        let ranged = Ranged::new(range("bytes=99-"), body().await);

        let err = ranged.try_respond().err().expect("try_respond should return Err");

        let expected_content_range = ContentRange::unsatisfied_bytes(54);
        assert_eq!(expected_content_range, err.0)
    }
}