hyper_timeout/
lib.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use hyper::rt::{Read, Write};
8use tokio::time::timeout;
9
10use hyper::Uri;
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use tower_service::Service;
13
14mod stream;
15use stream::TimeoutStream;
16
/// Boxed, thread-safe error type used as this connector's `Service::Error`.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
18
/// A connector that wraps another connector and enforces timeouts.
///
/// It bounds how long connecting may take, and applies read/write timeouts to
/// the connections it produces.
#[derive(Debug, Clone)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector (a `Service<Uri>`) that establishes connections.
    connector: T,
    /// Maximum time to wait for the wrapped connector to produce a connection.
    connect_timeout: Option<Duration>,
    /// Amount of time to wait when reading (e.g. the response).
    read_timeout: Option<Duration>,
    /// Amount of time to wait when writing (e.g. the request).
    write_timeout: Option<Duration>,
    /// If true, resets the read timeout whenever a write occurs.
    reset_reader_on_write: bool,
}
33
34impl<T> TimeoutConnector<T>
35where
36    T: Service<Uri> + Send,
37    T::Response: Read + Write + Send + Unpin,
38    T::Future: Send + 'static,
39    T::Error: Into<BoxError>,
40{
41    /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait
42    pub fn new(connector: T) -> Self {
43        TimeoutConnector {
44            connector,
45            connect_timeout: None,
46            read_timeout: None,
47            write_timeout: None,
48            reset_reader_on_write: false,
49        }
50    }
51}
52
53impl<T> Service<Uri> for TimeoutConnector<T>
54where
55    T: Service<Uri> + Send,
56    T::Response: Read + Write + Connection + Send + Unpin,
57    T::Future: Send + 'static,
58    T::Error: Into<BoxError>,
59{
60    type Response = Pin<Box<TimeoutStream<T::Response>>>;
61    type Error = BoxError;
62    #[allow(clippy::type_complexity)]
63    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
64
65    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66        self.connector.poll_ready(cx).map_err(Into::into)
67    }
68
69    fn call(&mut self, dst: Uri) -> Self::Future {
70        let connect_timeout = self.connect_timeout;
71        let read_timeout = self.read_timeout;
72        let write_timeout = self.write_timeout;
73        let reset_reader_on_write = self.reset_reader_on_write;
74        let connecting = self.connector.call(dst);
75
76        let fut = async move {
77            let mut stream = match connect_timeout {
78                None => {
79                    let io = connecting.await.map_err(Into::into)?;
80                    TimeoutStream::new(io)
81                }
82                Some(connect_timeout) => {
83                    let timeout = timeout(connect_timeout, connecting);
84                    let connecting = timeout
85                        .await
86                        .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
87                    let io = connecting.map_err(Into::into)?;
88                    TimeoutStream::new(io)
89                }
90            };
91            stream.set_read_timeout(read_timeout);
92            stream.set_write_timeout(write_timeout);
93            stream.set_reset_reader_on_write(reset_reader_on_write);
94            Ok(Box::pin(stream))
95        };
96
97        Box::pin(fut)
98    }
99}
100
101impl<T> TimeoutConnector<T> {
102    /// Set the timeout for connecting to a URL.
103    ///
104    /// Default is no timeout.
105    #[inline]
106    pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
107        self.connect_timeout = val;
108    }
109
110    /// Set the timeout for the response.
111    ///
112    /// Default is no timeout.
113    #[inline]
114    pub fn set_read_timeout(&mut self, val: Option<Duration>) {
115        self.read_timeout = val;
116    }
117
118    /// Set the timeout for the request.
119    ///
120    /// Default is no timeout.
121    #[inline]
122    pub fn set_write_timeout(&mut self, val: Option<Duration>) {
123        self.write_timeout = val;
124    }
125
126    /// Reset on the reader timeout on write
127    ///
128    /// This will reset the reader timeout when a write is done through the
129    /// the TimeoutReader. This is useful when you don't want to trigger
130    /// a reader timeout while writes are still be accepted.
131    pub fn set_reset_reader_on_write(&mut self, reset: bool) {
132        self.reset_reader_on_write = reset;
133    }
134}
135
136impl<T> Connection for TimeoutConnector<T>
137where
138    T: Read + Write + Connection + Service<Uri> + Send + Unpin,
139    T::Response: Read + Write + Send + Unpin,
140    T::Future: Send + 'static,
141    T::Error: Into<BoxError>,
142{
143    fn connected(&self) -> Connected {
144        self.connector.connected()
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::future::Future;
    use std::io;
    use std::net::TcpListener;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use std::thread;
    use std::time::Duration;

    use http_body_util::Empty;
    use hyper::body::Bytes;
    use hyper::Uri;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::{TokioExecutor, TokioIo},
    };
    use tokio::net::TcpStream;
    use tower_service::Service;

    use super::TimeoutConnector;

    /// Returns the first `io::Error` found in `err` or its `source()` chain.
    ///
    /// This avoids depending on how deeply client/hyper errors nest the cause.
    fn find_io_error<'a>(err: &'a (dyn Error + 'static)) -> Option<&'a io::Error> {
        let mut current = Some(err);
        while let Some(e) = current {
            if let Some(io_err) = e.downcast_ref::<io::Error>() {
                return Some(io_err);
            }
            current = e.source();
        }
        None
    }

    /// A connector whose connection attempt never completes.
    ///
    /// Only the connect timeout can end the attempt, which keeps the test
    /// independent of the host's network configuration.
    #[derive(Clone)]
    struct NeverConnect;

    impl Service<Uri> for NeverConnect {
        type Response = TokioIo<TcpStream>;
        type Error = io::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _dst: Uri) -> Self::Future {
            Box::pin(std::future::pending::<Result<Self::Response, Self::Error>>())
        }
    }

    #[tokio::test]
    async fn test_timeout_connector() {
        let url: Uri = "http://example.invalid".parse().unwrap();

        let mut connector = TimeoutConnector::new(NeverConnect);
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let err = match client.get(url).await {
            Ok(_) => panic!("Expected a timeout"),
            Err(e) => e,
        };
        let io_err = find_io_error(&err).expect("Expected timeout error");
        assert_eq!(io_err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn test_read_timeout() {
        // A local server that accepts the connection and swallows the request
        // but never answers. The response read can therefore only end via the
        // read timeout.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            if let Ok((mut sock, _)) = listener.accept() {
                // Keep the socket open until the client hangs up; closing it
                // early could race the timeout with an EOF.
                let _ = io::copy(&mut sock, &mut io::sink());
            }
        });
        let url: Uri = format!("http://{addr}").parse().unwrap();

        let http = HttpConnector::new();
        let mut connector = TimeoutConnector::new(http);
        // A 1 ms read timeout should be so short that we trigger a timeout error
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let err = match client.get(url).await {
            Ok(_) => panic!("Expected timeout error"),
            Err(e) => e,
        };
        let io_err = find_io_error(&err).expect("Expected timeout error");
        assert_eq!(io_err.kind(), io::ErrorKind::TimedOut);
    }
}