1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use hyper::rt::{Read, Write};
8use tokio::time::timeout;
9
10use hyper::Uri;
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use tower_service::Service;
13
14mod stream;
15use stream::TimeoutStream;
16
/// Boxed, thread-safe error object used as the error type of this connector.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
18
/// A connector wrapper that carries timeout settings alongside an inner
/// connector of type `T`.
///
/// The settings are plain values stored next to the wrapped connector; every
/// timeout is optional and `None` means "no limit".
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector that actually establishes connections.
    connector: T,
    /// Upper bound on how long establishing a connection may take.
    connect_timeout: Option<Duration>,
    /// Read timeout handed to every stream produced by this connector.
    read_timeout: Option<Duration>,
    /// Write timeout handed to every stream produced by this connector.
    write_timeout: Option<Duration>,
    /// Flag handed to every produced stream via `set_reset_reader_on_write`.
    reset_reader_on_write: bool,
}
33
34impl<T> TimeoutConnector<T>
35where
36 T: Service<Uri> + Send,
37 T::Response: Read + Write + Send + Unpin,
38 T::Future: Send + 'static,
39 T::Error: Into<BoxError>,
40{
41 pub fn new(connector: T) -> Self {
43 TimeoutConnector {
44 connector,
45 connect_timeout: None,
46 read_timeout: None,
47 write_timeout: None,
48 reset_reader_on_write: false,
49 }
50 }
51}
52
53impl<T> Service<Uri> for TimeoutConnector<T>
54where
55 T: Service<Uri> + Send,
56 T::Response: Read + Write + Connection + Send + Unpin,
57 T::Future: Send + 'static,
58 T::Error: Into<BoxError>,
59{
60 type Response = Pin<Box<TimeoutStream<T::Response>>>;
61 type Error = BoxError;
62 #[allow(clippy::type_complexity)]
63 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
64
65 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66 self.connector.poll_ready(cx).map_err(Into::into)
67 }
68
69 fn call(&mut self, dst: Uri) -> Self::Future {
70 let connect_timeout = self.connect_timeout;
71 let read_timeout = self.read_timeout;
72 let write_timeout = self.write_timeout;
73 let reset_reader_on_write = self.reset_reader_on_write;
74 let connecting = self.connector.call(dst);
75
76 let fut = async move {
77 let mut stream = match connect_timeout {
78 None => {
79 let io = connecting.await.map_err(Into::into)?;
80 TimeoutStream::new(io)
81 }
82 Some(connect_timeout) => {
83 let timeout = timeout(connect_timeout, connecting);
84 let connecting = timeout
85 .await
86 .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
87 let io = connecting.map_err(Into::into)?;
88 TimeoutStream::new(io)
89 }
90 };
91 stream.set_read_timeout(read_timeout);
92 stream.set_write_timeout(write_timeout);
93 stream.set_reset_reader_on_write(reset_reader_on_write);
94 Ok(Box::pin(stream))
95 };
96
97 Box::pin(fut)
98 }
99}
100
101impl<T> TimeoutConnector<T> {
102 #[inline]
106 pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
107 self.connect_timeout = val;
108 }
109
110 #[inline]
114 pub fn set_read_timeout(&mut self, val: Option<Duration>) {
115 self.read_timeout = val;
116 }
117
118 #[inline]
122 pub fn set_write_timeout(&mut self, val: Option<Duration>) {
123 self.write_timeout = val;
124 }
125
126 pub fn set_reset_reader_on_write(&mut self, reset: bool) {
132 self.reset_reader_on_write = reset;
133 }
134}
135
136impl<T> Connection for TimeoutConnector<T>
137where
138 T: Read + Write + Connection + Service<Uri> + Send + Unpin,
139 T::Response: Read + Write + Send + Unpin,
140 T::Future: Send + 'static,
141 T::Error: Into<BoxError>,
142{
143 fn connected(&self) -> Connected {
144 self.connector.connected()
145 }
146}
147
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::{error::Error, io};

    use http_body_util::Empty;
    use hyper::body::Bytes;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::TokioExecutor,
    };

    use super::TimeoutConnector;

    /// A 1 ms connect timeout must surface as an `io::Error` of kind
    /// `TimedOut` as the direct source of the client error.
    #[tokio::test]
    async fn test_timeout_connector() {
        // NOTE(review): this address is presumably unroutable, so the connect
        // attempt should outlast the 1 ms limit — depends on the network setup.
        let url = "http://10.255.255.1".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let outcome = client.get(url).await;

        let client_err = match outcome {
            Err(err) => err,
            Ok(_) => panic!("Expected a timeout"),
        };

        match client_err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }

    /// A 1 ms read timeout must surface as an `io::Error` of kind `TimedOut`
    /// two levels down the client error's source chain.
    #[tokio::test]
    async fn test_read_timeout() {
        let url = "http://example.com".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let outcome = client.get(url).await;

        // Anything other than "error -> source -> io::Error source" is a failure.
        let client_err = match outcome {
            Err(err) => err,
            Ok(_) => panic!("Expected timeout error"),
        };
        let hyper_err = match client_err.source() {
            Some(err) => err,
            None => panic!("Expected timeout error"),
        };
        match hyper_err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }
}