axum_range/
lib.rs

1//! # axum-range
2//!
3//! HTTP range responses for [`axum`][1].
4//!
5//! Fully generic, supports any body implementing the [`RangeBody`] trait.
6//!
//! Any type implementing both [`AsyncRead`] and [`AsyncSeekStart`] can be
//! used with the [`KnownSize`] adapter struct. There is also special-cased support
//! for [`tokio::fs::File`]; see the [`KnownSize::file`] method.
10//!
11//! [`AsyncSeekStart`] is a trait defined by this crate which only allows
12//! seeking from the start of a file. It is automatically implemented for any
13//! type implementing [`AsyncSeek`].
14//!
15//! ```
16//! use axum::Router;
17//! use axum::routing::get;
18//! use axum_extra::TypedHeader;
19//! use axum_extra::headers::Range;
20//!
21//! use tokio::fs::File;
22//!
23//! use axum_range::Ranged;
24//! use axum_range::KnownSize;
25//!
26//! async fn file(range: Option<TypedHeader<Range>>) -> Ranged<KnownSize<File>> {
27//!     let file = File::open("The Sims 1 - The Complete Collection.rar").await.unwrap();
28//!     let body = KnownSize::file(file).await.unwrap();
29//!     let range = range.map(|TypedHeader(range)| range);
30//!     Ranged::new(range, body)
31//! }
32//!
33//! #[tokio::main]
34//! async fn main() {
35//!     // build our application with a single route
36//!     let _app = Router::<()>::new().route("/", get(file));
37//!
//!     // serve it on all interfaces, port 3000 (only when the example feature is enabled)
//!     #[cfg(feature = "run_server_in_example")]
//!     {
//!         let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!         axum::serve(listener, _app).await.unwrap();
//!     }
44//! }
45//! ```
46//!
47//! [1]: https://docs.rs/axum
48
49mod file;
50mod stream;
51
52use std::io;
53use std::ops::Bound;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56
57use axum::http::StatusCode;
58use axum::response::{IntoResponse, Response};
59use axum_extra::TypedHeader;
60use axum_extra::headers::{Range, ContentRange, ContentLength, AcceptRanges};
61use tokio::io::{AsyncRead, AsyncSeek};
62
63pub use file::KnownSize;
64pub use stream::RangedStream;
65
66/// [`AsyncSeek`] narrowed to only allow seeking from start.
67pub trait AsyncSeekStart {
68    /// Same semantics as [`AsyncSeek::start_seek`], always passing position as the `SeekFrom::Start` variant.
69    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()>;
70
71    /// Same semantics as [`AsyncSeek::poll_complete`], returning `()` instead of the new stream position.
72    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
73}
74
75impl<T: AsyncSeek> AsyncSeekStart for T {
76    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()> {
77        AsyncSeek::start_seek(self, io::SeekFrom::Start(position))
78    }
79
80    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81        AsyncSeek::poll_complete(self, cx).map_ok(|_| ())
82    }
83}
84
85/// An [`AsyncRead`] and [`AsyncSeekStart`] with a fixed known byte size.
86pub trait RangeBody: AsyncRead + AsyncSeekStart {
87    /// The total size of the underlying file.
88    ///
89    /// This should not change for the lifetime of the object once queried.
90    /// Behaviour is not guaranteed if it does change.
91    fn byte_size(&self) -> u64;
92}
93
94/// The main responder type. Implements [`IntoResponse`].
95pub struct Ranged<B: RangeBody + Send + 'static> {
96    range: Option<Range>,
97    body: B,
98}
99
100impl<B: RangeBody + Send + 'static> Ranged<B> {
101    /// Construct a ranged response over any type implementing [`RangeBody`]
102    /// and an optional [`Range`] header.
103    pub fn new(range: Option<Range>, body: B) -> Self {
104        Ranged { range, body }
105    }
106
107    /// Responds to the request, returning headers and body as
108    /// [`RangedResponse`]. Returns [`RangeNotSatisfiable`] error if requested
109    /// range in header was not satisfiable.
110    pub fn try_respond(self) -> Result<RangedResponse<B>, RangeNotSatisfiable> {
111        let total_bytes = self.body.byte_size();
112
113        // we don't support multiple byte ranges, only none or one
114        // fortunately, only responding with one of the requested ranges and
115        // no more seems to be compliant with the HTTP spec.
116        let range = self.range.and_then(|range| {
117            range.satisfiable_ranges(total_bytes).nth(0)
118        });
119
120        // pull seek positions out of range header
121        let seek_start = match range {
122            Some((Bound::Included(seek_start), _)) => seek_start,
123            _ => 0,
124        };
125
126        let seek_end_excl = match range {
127            // HTTP byte ranges are inclusive, so we translate to exclusive by adding 1:
128            Some((_, Bound::Included(end))) => end + 1,
129            _ => total_bytes,
130        };
131
132        // check seek positions and return with 416 Range Not Satisfiable if invalid
133        let seek_start_beyond_seek_end = seek_start > seek_end_excl;
134        let seek_end_beyond_file_range = seek_end_excl > total_bytes;
135        // we could use >= above but I think this reads more clearly:
136        let zero_length_range = seek_start == seek_end_excl;
137
138        if seek_start_beyond_seek_end || seek_end_beyond_file_range || zero_length_range {
139            let content_range = ContentRange::unsatisfied_bytes(total_bytes);
140            return Err(RangeNotSatisfiable(content_range));
141        }
142
143        // if we're good, build the response
144        let content_range = range.map(|_| {
145            ContentRange::bytes(seek_start..seek_end_excl, total_bytes)
146                .expect("ContentRange::bytes cannot panic in this usage")
147        });
148
149        let content_length = ContentLength(seek_end_excl - seek_start);
150
151        let stream = RangedStream::new(self.body, seek_start, content_length.0);
152
153        Ok(RangedResponse {
154            content_range,
155            content_length,
156            stream,
157        })
158    }
159}
160
161impl<B: RangeBody + Send + 'static> IntoResponse for Ranged<B> {
162    fn into_response(self) -> Response {
163        self.try_respond().into_response()
164    }
165}
166
167/// Error type indicating that the requested range was not satisfiable. Implements [`IntoResponse`].
168#[derive(Debug, Clone)]
169pub struct RangeNotSatisfiable(pub ContentRange);
170
171impl IntoResponse for RangeNotSatisfiable {
172    fn into_response(self) -> Response {
173        let status = StatusCode::RANGE_NOT_SATISFIABLE;
174        let header = TypedHeader(self.0);
175        (status, header, ()).into_response()
176    }
177}
178
179/// Data type containing computed headers and body for a range response. Implements [`IntoResponse`].
180pub struct RangedResponse<B> {
181    pub content_range: Option<ContentRange>,
182    pub content_length: ContentLength,
183    pub stream: RangedStream<B>,
184}
185
186impl<B: RangeBody + Send + 'static> IntoResponse for RangedResponse<B> {
187    fn into_response(self) -> Response {
188        let content_range = self.content_range.map(TypedHeader);
189        let content_length = TypedHeader(self.content_length);
190        let accept_ranges = TypedHeader(AcceptRanges::bytes());
191        let stream = self.stream;
192
193        let status = match content_range {
194            Some(_) => StatusCode::PARTIAL_CONTENT,
195            None => StatusCode::OK,
196        };
197
198        (status, content_range, content_length, accept_ranges, stream).into_response()
199    }
200}
201
#[cfg(test)]
mod tests {
    use std::io;

    use axum::http::HeaderValue;
    use axum_extra::headers::{ContentRange, Header, Range};
    use bytes::Bytes;
    use futures::{pin_mut, Stream, StreamExt};
    use tokio::fs::File;

    use crate::KnownSize;
    use crate::Ranged;

    /// Size in bytes of `test/fixture.txt`.
    const FIXTURE_LEN: u64 = 54;

    /// Drains a byte stream into a `String`, panicking on I/O errors or invalid UTF-8.
    async fn collect_stream(stream: impl Stream<Item = io::Result<Bytes>>) -> String {
        pin_mut!(stream);
        let mut collected = String::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.unwrap();
            collected.push_str(std::str::from_utf8(&chunk).unwrap());
        }
        collected
    }

    /// Parses a raw `Range` header value.
    fn range(header: &str) -> Option<Range> {
        let value = HeaderValue::from_str(header).unwrap();
        let parsed = Range::decode(&mut std::iter::once(&value)).unwrap();
        Some(parsed)
    }

    /// Opens the shared fixture as a [`KnownSize`] body.
    async fn body() -> KnownSize<File> {
        let file = File::open("test/fixture.txt").await.unwrap();
        KnownSize::file(file).await.unwrap()
    }

    /// Requests `header` and asserts a successful partial response covering the
    /// half-open byte span `bytes`, whose body is `expected`.
    async fn assert_partial(header: &str, bytes: std::ops::Range<u64>, expected: &str) {
        let response = Ranged::new(range(header), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(bytes.end - bytes.start, response.content_length.0);

        let expected_content_range = ContentRange::bytes(bytes, FIXTURE_LEN).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!(expected, &collect_stream(response.stream).await);
    }

    /// Requests `header` and asserts it is rejected as unsatisfiable.
    async fn assert_unsatisfiable(header: &str) {
        let err = Ranged::new(range(header), body().await)
            .try_respond()
            .err()
            .expect("try_respond should return Err");

        assert_eq!(ContentRange::unsatisfied_bytes(FIXTURE_LEN), err.0);
    }

    #[tokio::test]
    async fn test_full_response() {
        let response = Ranged::new(None, body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(FIXTURE_LEN, response.content_length.0);
        assert!(response.content_range.is_none());
        assert_eq!(
            "Hello world this is a file to test range requests on!\n",
            &collect_stream(response.stream).await
        );
    }

    #[tokio::test]
    async fn test_partial_response_1() {
        assert_partial("bytes=0-29", 0..30, "Hello world this is a file to ").await;
    }

    #[tokio::test]
    async fn test_partial_response_2() {
        assert_partial("bytes=30-53", 30..54, "test range requests on!\n").await;
    }

    #[tokio::test]
    async fn test_unbounded_start_response() {
        // unbounded ranges in HTTP are actually a suffix
        assert_partial("bytes=-20", 34..54, " range requests on!\n").await;
    }

    #[tokio::test]
    async fn test_unbounded_end_response() {
        assert_partial("bytes=40-", 40..54, " requests on!\n").await;
    }

    #[tokio::test]
    async fn test_one_byte_response() {
        assert_partial("bytes=30-30", 30..31, "t").await;
    }

    #[tokio::test]
    async fn test_invalid_range() {
        assert_unsatisfiable("bytes=30-29").await;
    }

    #[tokio::test]
    async fn test_range_end_exceed_length() {
        assert_unsatisfiable("bytes=30-99").await;
    }

    #[tokio::test]
    async fn test_range_start_exceed_length() {
        assert_unsatisfiable("bytes=99-").await;
    }
}