1use crate::client::backoff::{Backoff, BackoffConfig};
21use crate::PutPayload;
22use futures::future::BoxFuture;
23use reqwest::header::LOCATION;
24use reqwest::{Client, Request, Response, StatusCode};
25use snafu::Error as SnafuError;
26use snafu::Snafu;
27use std::time::{Duration, Instant};
28use tracing::info;
29
30#[derive(Debug, Snafu)]
32pub enum Error {
33 #[snafu(display("Received redirect without LOCATION, this normally indicates an incorrectly configured region"))]
34 BareRedirect,
35
36 #[snafu(display("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body")))]
37 Client {
38 status: StatusCode,
39 body: Option<String>,
40 },
41
42 #[snafu(display("Error after {retries} retries in {elapsed:?}, max_retries:{max_retries}, retry_timeout:{retry_timeout:?}, source:{source}"))]
43 Reqwest {
44 retries: usize,
45 max_retries: usize,
46 elapsed: Duration,
47 retry_timeout: Duration,
48 source: reqwest::Error,
49 },
50}
51
52impl Error {
53 pub fn status(&self) -> Option<StatusCode> {
55 match self {
56 Self::BareRedirect => None,
57 Self::Client { status, .. } => Some(*status),
58 Self::Reqwest { source, .. } => source.status(),
59 }
60 }
61
62 pub fn body(&self) -> Option<&str> {
64 match self {
65 Self::Client { body, .. } => body.as_deref(),
66 Self::BareRedirect => None,
67 Self::Reqwest { .. } => None,
68 }
69 }
70
71 pub fn error(self, store: &'static str, path: String) -> crate::Error {
72 match self.status() {
73 Some(StatusCode::NOT_FOUND) => crate::Error::NotFound {
74 path,
75 source: Box::new(self),
76 },
77 Some(StatusCode::NOT_MODIFIED) => crate::Error::NotModified {
78 path,
79 source: Box::new(self),
80 },
81 Some(StatusCode::PRECONDITION_FAILED) => crate::Error::Precondition {
82 path,
83 source: Box::new(self),
84 },
85 Some(StatusCode::CONFLICT) => crate::Error::AlreadyExists {
86 path,
87 source: Box::new(self),
88 },
89 _ => crate::Error::Generic {
90 store,
91 source: Box::new(self),
92 },
93 }
94 }
95}
96
97impl From<Error> for std::io::Error {
98 fn from(err: Error) -> Self {
99 use std::io::ErrorKind;
100 match &err {
101 Error::Client {
102 status: StatusCode::NOT_FOUND,
103 ..
104 } => Self::new(ErrorKind::NotFound, err),
105 Error::Client {
106 status: StatusCode::BAD_REQUEST,
107 ..
108 } => Self::new(ErrorKind::InvalidInput, err),
109 Error::Reqwest { source, .. } if source.is_timeout() => {
110 Self::new(ErrorKind::TimedOut, err)
111 }
112 Error::Reqwest { source, .. } if source.is_connect() => {
113 Self::new(ErrorKind::NotConnected, err)
114 }
115 _ => Self::new(ErrorKind::Other, err),
116 }
117 }
118}
119
120pub type Result<T, E = Error> = std::result::Result<T, E>;
121
122#[derive(Debug, Clone)]
136pub struct RetryConfig {
137 pub backoff: BackoffConfig,
139
140 pub max_retries: usize,
144
145 pub retry_timeout: Duration,
157}
158
159impl Default for RetryConfig {
160 fn default() -> Self {
161 Self {
162 backoff: Default::default(),
163 max_retries: 10,
164 retry_timeout: Duration::from_secs(3 * 60),
165 }
166 }
167}
168
169pub struct RetryableRequest {
170 client: Client,
171 request: Request,
172
173 max_retries: usize,
174 retry_timeout: Duration,
175 backoff: Backoff,
176
177 sensitive: bool,
178 idempotent: Option<bool>,
179 payload: Option<PutPayload>,
180}
181
182impl RetryableRequest {
183 pub fn idempotent(self, idempotent: bool) -> Self {
188 Self {
189 idempotent: Some(idempotent),
190 ..self
191 }
192 }
193
194 #[allow(unused)]
198 pub fn sensitive(self, sensitive: bool) -> Self {
199 Self { sensitive, ..self }
200 }
201
202 pub fn payload(self, payload: Option<PutPayload>) -> Self {
204 Self { payload, ..self }
205 }
206
207 pub async fn send(self) -> Result<Response> {
208 let max_retries = self.max_retries;
209 let retry_timeout = self.retry_timeout;
210 let mut retries = 0;
211 let now = Instant::now();
212
213 let mut backoff = self.backoff;
214 let is_idempotent = self
215 .idempotent
216 .unwrap_or_else(|| self.request.method().is_safe());
217
218 let sanitize_err = move |e: reqwest::Error| match self.sensitive {
219 true => e.without_url(),
220 false => e,
221 };
222
223 loop {
224 let mut request = self
225 .request
226 .try_clone()
227 .expect("request body must be cloneable");
228
229 if let Some(payload) = &self.payload {
230 *request.body_mut() = Some(payload.body());
231 }
232
233 match self.client.execute(request).await {
234 Ok(r) => match r.error_for_status_ref() {
235 Ok(_) if r.status().is_success() => return Ok(r),
236 Ok(r) if r.status() == StatusCode::NOT_MODIFIED => {
237 return Err(Error::Client {
238 body: None,
239 status: StatusCode::NOT_MODIFIED,
240 })
241 }
242 Ok(r) => {
243 let is_bare_redirect =
244 r.status().is_redirection() && !r.headers().contains_key(LOCATION);
245 return match is_bare_redirect {
246 true => Err(Error::BareRedirect),
247 false => Err(Error::Client {
249 body: None,
250 status: r.status(),
251 }),
252 };
253 }
254 Err(e) => {
255 let e = sanitize_err(e);
256 let status = r.status();
257 if retries == max_retries
258 || now.elapsed() > retry_timeout
259 || !status.is_server_error()
260 {
261 return Err(match status.is_client_error() {
262 true => match r.text().await {
263 Ok(body) => Error::Client {
264 body: Some(body).filter(|b| !b.is_empty()),
265 status,
266 },
267 Err(e) => Error::Reqwest {
268 retries,
269 max_retries,
270 elapsed: now.elapsed(),
271 retry_timeout,
272 source: e,
273 },
274 },
275 false => Error::Reqwest {
276 retries,
277 max_retries,
278 elapsed: now.elapsed(),
279 retry_timeout,
280 source: e,
281 },
282 });
283 }
284
285 let sleep = backoff.next();
286 retries += 1;
287 info!(
288 "Encountered server error, backing off for {} seconds, retry {} of {}: {}",
289 sleep.as_secs_f32(),
290 retries,
291 max_retries,
292 e,
293 );
294 tokio::time::sleep(sleep).await;
295 }
296 },
297 Err(e) => {
298 let e = sanitize_err(e);
299
300 let mut do_retry = false;
301 if e.is_connect()
302 || e.is_body()
303 || (e.is_request() && !e.is_timeout())
304 || (is_idempotent && e.is_timeout())
305 {
306 do_retry = true
307 } else {
308 let mut source = e.source();
309 while let Some(e) = source {
310 if let Some(e) = e.downcast_ref::<hyper::Error>() {
311 do_retry = e.is_closed()
312 || e.is_incomplete_message()
313 || e.is_body_write_aborted()
314 || (is_idempotent && e.is_timeout());
315 break;
316 }
317 if let Some(e) = e.downcast_ref::<std::io::Error>() {
318 if e.kind() == std::io::ErrorKind::TimedOut {
319 do_retry = is_idempotent;
320 } else {
321 do_retry = matches!(
322 e.kind(),
323 std::io::ErrorKind::ConnectionReset
324 | std::io::ErrorKind::ConnectionAborted
325 | std::io::ErrorKind::BrokenPipe
326 | std::io::ErrorKind::UnexpectedEof
327 );
328 }
329 break;
330 }
331 source = e.source();
332 }
333 }
334
335 if retries == max_retries || now.elapsed() > retry_timeout || !do_retry {
336 return Err(Error::Reqwest {
337 retries,
338 max_retries,
339 elapsed: now.elapsed(),
340 retry_timeout,
341 source: e,
342 });
343 }
344 let sleep = backoff.next();
345 retries += 1;
346 info!(
347 "Encountered transport error backing off for {} seconds, retry {} of {}: {}",
348 sleep.as_secs_f32(),
349 retries,
350 max_retries,
351 e,
352 );
353 tokio::time::sleep(sleep).await;
354 }
355 }
356 }
357 }
358}
359
360pub trait RetryExt {
361 fn retryable(self, config: &RetryConfig) -> RetryableRequest;
363
364 fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>>;
370}
371
372impl RetryExt for reqwest::RequestBuilder {
373 fn retryable(self, config: &RetryConfig) -> RetryableRequest {
374 let (client, request) = self.build_split();
375 let request = request.expect("request must be valid");
376
377 RetryableRequest {
378 client,
379 request,
380 max_retries: config.max_retries,
381 retry_timeout: config.retry_timeout,
382 backoff: Backoff::new(&config.backoff),
383 idempotent: None,
384 payload: None,
385 sensitive: false,
386 }
387 }
388
389 fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>> {
390 let request = self.retryable(config);
391 Box::pin(async move { request.send().await })
392 }
393}
394
#[cfg(test)]
mod tests {
    use crate::client::mock_server::MockServer;
    use crate::client::retry::{Error, RetryExt};
    use crate::RetryConfig;
    use hyper::header::LOCATION;
    use hyper::Response;
    use reqwest::{Client, Method, StatusCode};
    use std::time::Duration;

    /// Drives the retry logic against a mock server, one scenario after
    /// another: success, client errors, server errors, redirects, transport
    /// failures, timeouts and URL sanitization.
    #[tokio::test]
    async fn test_retry() {
        let mock = MockServer::new().await;

        let retry = RetryConfig {
            max_retries: 2,
            retry_timeout: Duration::from_secs(1000),
            backoff: Default::default(),
        };

        // A short client timeout keeps the timeout scenarios fast
        let client = Client::builder()
            .timeout(Duration::from_millis(100))
            .build()
            .unwrap();

        let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry);

        // Queues a response with the given status and body
        let respond = |status: StatusCode, body: &str| {
            mock.push(
                Response::builder()
                    .status(status)
                    .body(body.to_string())
                    .unwrap(),
            );
        };

        // Queues a 302 redirect pointing at `location`
        let redirect_to = |location: &str| {
            mock.push(
                Response::builder()
                    .status(StatusCode::FOUND)
                    .header(LOCATION, location)
                    .body(String::new())
                    .unwrap(),
            );
        };

        // Queues one 502 per permitted attempt (initial try plus every retry)
        let exhaust_with_bad_gateway = || {
            (0..=retry.max_retries).for_each(|_| respond(StatusCode::BAD_GATEWAY, "ignored"));
        };

        // Queues one panicking handler per permitted attempt
        let exhaust_with_panics = || {
            (0..=retry.max_retries).for_each(|_| mock.push_fn(|_| panic!()));
        };

        // With nothing queued the request simply succeeds
        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // Client errors are returned immediately, body included
        respond(StatusCode::BAD_REQUEST, "cupcakes");

        let e = do_request().await.unwrap_err();
        assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST);
        assert_eq!(e.body(), Some("cupcakes"));
        assert_eq!(
            e.to_string(),
            "Client error with status 400 Bad Request: cupcakes"
        );

        // An empty body is reported as absent
        respond(StatusCode::BAD_REQUEST, "");

        let e = do_request().await.unwrap_err();
        assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST);
        assert_eq!(e.body(), None);
        assert_eq!(
            e.to_string(),
            "Client error with status 400 Bad Request: No Body"
        );

        // A single server error is retried and the retry succeeds
        respond(StatusCode::BAD_GATEWAY, "");

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // Any 2xx status counts as success
        respond(StatusCode::NO_CONTENT, "");

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::NO_CONTENT);

        // Redirects carrying a location are followed
        redirect_to("/foo");

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.url().path(), "/foo");

        redirect_to("/bar");

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.url().path(), "/bar");

        // A long chain of redirects is eventually abandoned by the client
        (0..10).for_each(|_| redirect_to("/bar"));

        let e = do_request().await.unwrap_err().to_string();
        assert!(e.contains("error following redirect for url"), "{}", e);

        // A redirect without a location gets its own error
        mock.push(
            Response::builder()
                .status(StatusCode::FOUND)
                .body(String::new())
                .unwrap(),
        );

        let e = do_request().await.unwrap_err();
        assert!(matches!(e, Error::BareRedirect));
        assert_eq!(e.to_string(), "Received redirect without LOCATION, this normally indicates an incorrectly configured region");

        // Persistent server errors exhaust the retry budget
        exhaust_with_bad_gateway();

        let e = do_request().await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 2 retries in") &&
            e.contains("max_retries:2, retry_timeout:1000s, source:HTTP status server error (502 Bad Gateway) for url"),
            "{e}"
        );

        // A panicking handler surfaces as a transport failure, which is retried
        mock.push_fn(|_| panic!());
        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // Repeated transport failures also exhaust the retry budget
        exhaust_with_panics();
        let e = do_request().await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 2 retries in")
                && e.contains(
                    "max_retries:2, retry_timeout:1000s, source:error sending request for url"
                ),
            "{e}"
        );

        // A GET that times out on the client side is retried
        mock.push_async_fn(|_| async move {
            tokio::time::sleep(Duration::from_secs(10)).await;
            panic!()
        });
        do_request().await.unwrap();

        // A PUT that times out is not retried
        mock.push_async_fn(|_| async move {
            tokio::time::sleep(Duration::from_secs(10)).await;
            panic!()
        });
        let res = client.request(Method::PUT, mock.url()).send_retry(&retry);
        let e = res.await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 0 retries in") && e.contains("error sending request for url"),
            "{e}"
        );

        let url = format!("{}/SENSITIVE", mock.url());

        // By default the URL shows up in the error message
        exhaust_with_bad_gateway();
        let res = client.request(Method::GET, &url).send_retry(&retry).await;
        let err = res.unwrap_err().to_string();
        assert!(err.contains("SENSITIVE"), "{err}");

        // A sensitive request hides the URL on server errors
        exhaust_with_bad_gateway();

        let req = client
            .request(Method::GET, &url)
            .retryable(&retry)
            .sensitive(true);
        let err = req.send().await.unwrap_err().to_string();
        assert!(!err.contains("SENSITIVE"), "{err}");

        // ... and on transport errors
        exhaust_with_panics();

        let req = client
            .request(Method::GET, &url)
            .retryable(&retry)
            .sensitive(true);
        let err = req.send().await.unwrap_err().to_string();
        assert!(!err.contains("SENSITIVE"), "{err}");

        mock.shutdown().await
    }
}