1use crate::aws::{AwsCredentialProvider, STORE, STRICT_ENCODE_SET, STRICT_PATH_ENCODE_SET};
19use crate::client::retry::RetryExt;
20use crate::client::token::{TemporaryToken, TokenCache};
21use crate::client::TokenProvider;
22use crate::util::{hex_digest, hex_encode, hmac_sha256};
23use crate::{CredentialProvider, Result, RetryConfig};
24use async_trait::async_trait;
25use bytes::Buf;
26use chrono::{DateTime, Utc};
27use hyper::header::HeaderName;
28use percent_encoding::utf8_percent_encode;
29use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
30use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
31use serde::Deserialize;
32use snafu::{ResultExt, Snafu};
33use std::collections::BTreeMap;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tracing::warn;
37use url::Url;
38
/// Errors raised while obtaining session credentials through the
/// `CreateSession` request issued by [`SessionProvider`].
#[derive(Debug, Snafu)]
#[allow(clippy::enum_variant_names)]
enum Error {
    /// Sending the `CreateSession` request failed; the source is the error
    /// reported by the retrying client.
    #[snafu(display("Error performing CreateSession request: {source}"))]
    CreateSessionRequest { source: crate::client::retry::Error },

    /// The response body of the `CreateSession` request could not be read.
    #[snafu(display("Error getting CreateSession response: {source}"))]
    CreateSessionResponse { source: reqwest::Error },

    /// The response body could not be deserialized as XML into
    /// `CreateSessionOutput`.
    #[snafu(display("Invalid CreateSessionOutput response: {source}"))]
    CreateSessionOutput { source: quick_xml::DeError },
}
51
52impl From<Error> for crate::Error {
53 fn from(value: Error) -> Self {
54 Self::Generic {
55 store: STORE,
56 source: Box::new(value),
57 }
58 }
59}
60
/// Boxed, thread-safe error type returned by the credential-fetching helpers
/// in this module.
type StdError = Box<dyn std::error::Error + Send + Sync>;

/// Hex-encoded SHA-256 digest of the empty byte string; used as the payload
/// hash for requests that carry no body.
static EMPTY_SHA256_HASH: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
/// Payload-hash placeholder used when the payload is not signed.
static UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
/// Payload-hash placeholder used when the request body is not available as
/// an in-memory byte slice.
static STREAMING_PAYLOAD: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD";
67
/// A set of AWS credentials used to sign requests.
#[derive(Debug, Eq, PartialEq)]
pub struct AwsCredential {
    /// Access key id; emitted in the `Credential=` part of the signature.
    pub key_id: String,
    /// Secret access key; seeds the HMAC chain that derives the signing key.
    pub secret_key: String,
    /// Optional session token; when present it is sent as a request header
    /// or as the `X-Amz-Security-Token` query parameter of a signed URL.
    pub token: Option<String>,
}
78
79impl AwsCredential {
80 fn sign(&self, to_sign: &str, date: DateTime<Utc>, region: &str, service: &str) -> String {
84 let date_string = date.format("%Y%m%d").to_string();
85 let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string);
86 let region_hmac = hmac_sha256(date_hmac, region);
87 let service_hmac = hmac_sha256(region_hmac, service);
88 let signing_hmac = hmac_sha256(service_hmac, b"aws4_request");
89 hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref())
90 }
91}
92
/// Signs HTTP requests and URLs with the `AWS4-HMAC-SHA256` scheme using a
/// borrowed [`AwsCredential`].
#[derive(Debug)]
pub struct AwsAuthorizer<'a> {
    /// Fixed signing time; when `None` the current UTC time is used at
    /// signing time. The tests in this file set it to get reproducible
    /// signatures.
    date: Option<DateTime<Utc>>,
    /// Credential used to produce the signature.
    credential: &'a AwsCredential,
    /// Service name placed in the credential scope (e.g. `"s3"`).
    service: &'a str,
    /// Region name placed in the credential scope.
    region: &'a str,
    /// Header used to carry the session token; falls back to
    /// `x-amz-security-token` when `None`.
    token_header: Option<HeaderName>,
    /// Whether the payload hash is signed; when `false` the
    /// `UNSIGNED-PAYLOAD` placeholder is used instead.
    sign_payload: bool,
}
105
/// Header carrying the request timestamp.
static DATE_HEADER: HeaderName = HeaderName::from_static("x-amz-date");
/// Header carrying the payload hash (or one of the payload placeholders).
static HASH_HEADER: HeaderName = HeaderName::from_static("x-amz-content-sha256");
/// Default header carrying the session token.
static TOKEN_HEADER: HeaderName = HeaderName::from_static("x-amz-security-token");
/// Signing algorithm identifier used in the string to sign, the
/// `Authorization` header and the `X-Amz-Algorithm` query parameter.
const ALGORITHM: &str = "AWS4-HMAC-SHA256";
110
111impl<'a> AwsAuthorizer<'a> {
112 pub fn new(credential: &'a AwsCredential, service: &'a str, region: &'a str) -> Self {
114 Self {
115 credential,
116 service,
117 region,
118 date: None,
119 sign_payload: true,
120 token_header: None,
121 }
122 }
123
124 pub fn with_sign_payload(mut self, signed: bool) -> Self {
127 self.sign_payload = signed;
128 self
129 }
130
131 pub(crate) fn with_token_header(mut self, header: HeaderName) -> Self {
133 self.token_header = Some(header);
134 self
135 }
136
137 pub fn authorize(&self, request: &mut Request, pre_calculated_digest: Option<&[u8]>) {
151 if let Some(ref token) = self.credential.token {
152 let token_val = HeaderValue::from_str(token).unwrap();
153 let header = self.token_header.as_ref().unwrap_or(&TOKEN_HEADER);
154 request.headers_mut().insert(header, token_val);
155 }
156
157 let host = &request.url()[url::Position::BeforeHost..url::Position::AfterPort];
158 let host_val = HeaderValue::from_str(host).unwrap();
159 request.headers_mut().insert("host", host_val);
160
161 let date = self.date.unwrap_or_else(Utc::now);
162 let date_str = date.format("%Y%m%dT%H%M%SZ").to_string();
163 let date_val = HeaderValue::from_str(&date_str).unwrap();
164 request.headers_mut().insert(&DATE_HEADER, date_val);
165
166 let digest = match self.sign_payload {
167 false => UNSIGNED_PAYLOAD.to_string(),
168 true => match pre_calculated_digest {
169 Some(digest) => hex_encode(digest),
170 None => match request.body() {
171 None => EMPTY_SHA256_HASH.to_string(),
172 Some(body) => match body.as_bytes() {
173 Some(bytes) => hex_digest(bytes),
174 None => STREAMING_PAYLOAD.to_string(),
175 },
176 },
177 },
178 };
179
180 let header_digest = HeaderValue::from_str(&digest).unwrap();
181 request.headers_mut().insert(&HASH_HEADER, header_digest);
182
183 let (signed_headers, canonical_headers) = canonicalize_headers(request.headers());
184
185 let scope = self.scope(date);
186
187 let string_to_sign = self.string_to_sign(
188 date,
189 &scope,
190 request.method(),
191 request.url(),
192 &canonical_headers,
193 &signed_headers,
194 &digest,
195 );
196
197 let signature = self
199 .credential
200 .sign(&string_to_sign, date, self.region, self.service);
201
202 let authorisation = format!(
204 "{} Credential={}/{}, SignedHeaders={}, Signature={}",
205 ALGORITHM, self.credential.key_id, scope, signed_headers, signature
206 );
207
208 let authorization_val = HeaderValue::from_str(&authorisation).unwrap();
209 request
210 .headers_mut()
211 .insert(&AUTHORIZATION, authorization_val);
212 }
213
214 pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) {
215 let date = self.date.unwrap_or_else(Utc::now);
216 let scope = self.scope(date);
217
218 url.query_pairs_mut()
220 .append_pair("X-Amz-Algorithm", ALGORITHM)
221 .append_pair(
222 "X-Amz-Credential",
223 &format!("{}/{}", self.credential.key_id, scope),
224 )
225 .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string())
226 .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string())
227 .append_pair("X-Amz-SignedHeaders", "host");
228
229 if let Some(ref token) = self.credential.token {
232 url.query_pairs_mut()
233 .append_pair("X-Amz-Security-Token", token);
234 }
235
236 let digest = UNSIGNED_PAYLOAD;
238
239 let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string();
240 let mut headers = HeaderMap::new();
241 let host_val = HeaderValue::from_str(host).unwrap();
242 headers.insert("host", host_val);
243
244 let (signed_headers, canonical_headers) = canonicalize_headers(&headers);
245
246 let string_to_sign = self.string_to_sign(
247 date,
248 &scope,
249 &method,
250 url,
251 &canonical_headers,
252 &signed_headers,
253 digest,
254 );
255
256 let signature = self
257 .credential
258 .sign(&string_to_sign, date, self.region, self.service);
259
260 url.query_pairs_mut()
261 .append_pair("X-Amz-Signature", &signature);
262 }
263
264 #[allow(clippy::too_many_arguments)]
265 fn string_to_sign(
266 &self,
267 date: DateTime<Utc>,
268 scope: &str,
269 request_method: &Method,
270 url: &Url,
271 canonical_headers: &str,
272 signed_headers: &str,
273 digest: &str,
274 ) -> String {
275 let canonical_uri = match self.service {
279 "s3" => url.path().to_string(),
280 _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(),
281 };
282
283 let canonical_query = canonicalize_query(url);
284
285 let canonical_request = format!(
287 "{}\n{}\n{}\n{}\n{}\n{}",
288 request_method.as_str(),
289 canonical_uri,
290 canonical_query,
291 canonical_headers,
292 signed_headers,
293 digest
294 );
295
296 let hashed_canonical_request = hex_digest(canonical_request.as_bytes());
297
298 format!(
299 "{}\n{}\n{}\n{}",
300 ALGORITHM,
301 date.format("%Y%m%dT%H%M%SZ"),
302 scope,
303 hashed_canonical_request
304 )
305 }
306
307 fn scope(&self, date: DateTime<Utc>) -> String {
308 format!(
309 "{}/{}/{}/aws4_request",
310 date.format("%Y%m%d"),
311 self.region,
312 self.service
313 )
314 }
315}
316
/// Extension trait for applying AWS SigV4 signing to a request builder.
pub trait CredentialExt {
    /// Signs the request with `authorizer` when one is supplied.
    ///
    /// `payload_sha256` is an optional pre-computed payload digest that is
    /// forwarded to the authorizer.
    fn with_aws_sigv4(
        self,
        authorizer: Option<AwsAuthorizer<'_>>,
        payload_sha256: Option<&[u8]>,
    ) -> Self;
}
325
326impl CredentialExt for RequestBuilder {
327 fn with_aws_sigv4(
328 self,
329 authorizer: Option<AwsAuthorizer<'_>>,
330 payload_sha256: Option<&[u8]>,
331 ) -> Self {
332 match authorizer {
333 Some(authorizer) => {
334 let (client, request) = self.build_split();
335 let mut request = request.expect("request valid");
336 authorizer.authorize(&mut request, payload_sha256);
337
338 Self::from_parts(client, request)
339 }
340 None => self,
341 }
342 }
343}
344
345fn canonicalize_query(url: &Url) -> String {
349 use std::fmt::Write;
350
351 let capacity = match url.query() {
352 Some(q) if !q.is_empty() => q.len(),
353 _ => return String::new(),
354 };
355 let mut encoded = String::with_capacity(capacity + 1);
356
357 let mut headers = url.query_pairs().collect::<Vec<_>>();
358 headers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
359
360 let mut first = true;
361 for (k, v) in headers {
362 if !first {
363 encoded.push('&');
364 }
365 first = false;
366 let _ = write!(
367 encoded,
368 "{}={}",
369 utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET),
370 utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET)
371 );
372 }
373 encoded
374}
375
376fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) {
380 let mut headers = BTreeMap::<&str, Vec<&str>>::new();
381 let mut value_count = 0;
382 let mut value_bytes = 0;
383 let mut key_bytes = 0;
384
385 for (key, value) in header_map {
386 let key = key.as_str();
387 if ["authorization", "content-length", "user-agent"].contains(&key) {
388 continue;
389 }
390
391 let value = std::str::from_utf8(value.as_bytes()).unwrap();
392 key_bytes += key.len();
393 value_bytes += value.len();
394 value_count += 1;
395 headers.entry(key).or_default().push(value);
396 }
397
398 let mut signed_headers = String::with_capacity(key_bytes + headers.len());
399 let mut canonical_headers =
400 String::with_capacity(key_bytes + value_bytes + headers.len() + value_count);
401
402 for (header_idx, (name, values)) in headers.into_iter().enumerate() {
403 if header_idx != 0 {
404 signed_headers.push(';');
405 }
406
407 signed_headers.push_str(name);
408 canonical_headers.push_str(name);
409 canonical_headers.push(':');
410 for (value_idx, value) in values.into_iter().enumerate() {
411 if value_idx != 0 {
412 canonical_headers.push(',');
413 }
414 canonical_headers.push_str(value.trim());
415 }
416 canonical_headers.push('\n');
417 }
418
419 (signed_headers, canonical_headers)
420}
421
/// Credential provider that fetches credentials from an instance metadata
/// endpoint.
#[derive(Debug)]
pub struct InstanceCredentialProvider {
    /// Whether to continue without a metadata token when the token request
    /// is answered with `403 Forbidden`.
    pub imdsv1_fallback: bool,
    /// Base URL of the metadata endpoint.
    pub metadata_endpoint: String,
}
430
431#[async_trait]
432impl TokenProvider for InstanceCredentialProvider {
433 type Credential = AwsCredential;
434
435 async fn fetch_token(
436 &self,
437 client: &Client,
438 retry: &RetryConfig,
439 ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
440 instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback)
441 .await
442 .map_err(|source| crate::Error::Generic {
443 store: STORE,
444 source,
445 })
446 }
447}
448
/// Credential provider that exchanges a web identity token for temporary
/// credentials via `AssumeRoleWithWebIdentity`.
#[derive(Debug)]
pub struct WebIdentityProvider {
    /// Path of the file containing the web identity token.
    pub token_path: String,
    /// Value sent as the `RoleArn` request parameter.
    pub role_arn: String,
    /// Value sent as the `RoleSessionName` request parameter.
    pub session_name: String,
    /// URL the `AssumeRoleWithWebIdentity` request is posted to.
    pub endpoint: String,
}
459
460#[async_trait]
461impl TokenProvider for WebIdentityProvider {
462 type Credential = AwsCredential;
463
464 async fn fetch_token(
465 &self,
466 client: &Client,
467 retry: &RetryConfig,
468 ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
469 web_identity(
470 client,
471 retry,
472 &self.token_path,
473 &self.role_arn,
474 &self.session_name,
475 &self.endpoint,
476 )
477 .await
478 .map_err(|source| crate::Error::Generic {
479 store: STORE,
480 source,
481 })
482 }
483}
484
/// Credential document returned by the metadata and task credential
/// endpoints; field names are deserialized from their PascalCase form
/// (e.g. `AccessKeyId`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct InstanceCredentials {
    /// Access key id.
    access_key_id: String,
    /// Secret access key.
    secret_access_key: String,
    /// Session token.
    token: String,
    /// Time at which the credentials expire; used to compute the cache TTL.
    expiration: DateTime<Utc>,
}
493
494impl From<InstanceCredentials> for AwsCredential {
495 fn from(s: InstanceCredentials) -> Self {
496 Self {
497 key_id: s.access_key_id,
498 secret_key: s.secret_access_key,
499 token: Some(s.token),
500 }
501 }
502}
503
504async fn instance_creds(
506 client: &Client,
507 retry_config: &RetryConfig,
508 endpoint: &str,
509 imdsv1_fallback: bool,
510) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
511 const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
512 const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
513
514 let token_url = format!("{endpoint}/latest/api/token");
515
516 let token_result = client
517 .request(Method::PUT, token_url)
518 .header("X-aws-ec2-metadata-token-ttl-seconds", "600") .retryable(retry_config)
520 .idempotent(true)
521 .send()
522 .await;
523
524 let token = match token_result {
525 Ok(t) => Some(t.text().await?),
526 Err(e) if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) => {
527 warn!("received 403 from metadata endpoint, falling back to IMDSv1");
528 None
529 }
530 Err(e) => return Err(e.into()),
531 };
532
533 let role_url = format!("{endpoint}/{CREDENTIALS_PATH}/");
534 let mut role_request = client.request(Method::GET, role_url);
535
536 if let Some(token) = &token {
537 role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
538 }
539
540 let role = role_request.send_retry(retry_config).await?.text().await?;
541
542 let creds_url = format!("{endpoint}/{CREDENTIALS_PATH}/{role}");
543 let mut creds_request = client.request(Method::GET, creds_url);
544 if let Some(token) = &token {
545 creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
546 }
547
548 let creds: InstanceCredentials = creds_request.send_retry(retry_config).await?.json().await?;
549
550 let now = Utc::now();
551 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
552 Ok(TemporaryToken {
553 token: Arc::new(creds.into()),
554 expiry: Some(Instant::now() + ttl),
555 })
556}
557
/// Top-level element of the `AssumeRoleWithWebIdentity` XML response;
/// field names are deserialized from their PascalCase form.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AssumeRoleResponse {
    /// The `AssumeRoleWithWebIdentityResult` element.
    assume_role_with_web_identity_result: AssumeRoleResult,
}
563
/// Result element of the `AssumeRoleWithWebIdentity` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AssumeRoleResult {
    /// The `Credentials` element holding the temporary credentials.
    credentials: SessionCredentials,
}
569
/// Temporary credentials as returned inside the
/// `AssumeRoleWithWebIdentity` and `CreateSession` responses; field names
/// are deserialized from their PascalCase form.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct SessionCredentials {
    /// Session token.
    session_token: String,
    /// Secret access key.
    secret_access_key: String,
    /// Access key id.
    access_key_id: String,
    /// Time at which the credentials expire.
    expiration: DateTime<Utc>,
}
578
579impl From<SessionCredentials> for AwsCredential {
580 fn from(s: SessionCredentials) -> Self {
581 Self {
582 key_id: s.access_key_id,
583 secret_key: s.secret_access_key,
584 token: Some(s.session_token),
585 }
586 }
587}
588
589async fn web_identity(
591 client: &Client,
592 retry_config: &RetryConfig,
593 token_path: &str,
594 role_arn: &str,
595 session_name: &str,
596 endpoint: &str,
597) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
598 let token = std::fs::read_to_string(token_path)
599 .map_err(|e| format!("Failed to read token file '{token_path}': {e}"))?;
600
601 let bytes = client
602 .request(Method::POST, endpoint)
603 .query(&[
604 ("Action", "AssumeRoleWithWebIdentity"),
605 ("DurationSeconds", "3600"),
606 ("RoleArn", role_arn),
607 ("RoleSessionName", session_name),
608 ("Version", "2011-06-15"),
609 ("WebIdentityToken", &token),
610 ])
611 .retryable(retry_config)
612 .idempotent(true)
613 .sensitive(true)
614 .send()
615 .await?
616 .bytes()
617 .await?;
618
619 let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader())
620 .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {e}"))?;
621
622 let creds = resp.assume_role_with_web_identity_result.credentials;
623 let now = Utc::now();
624 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
625
626 Ok(TemporaryToken {
627 token: Arc::new(creds.into()),
628 expiry: Some(Instant::now() + ttl),
629 })
630}
631
/// Credential provider that fetches credentials from a task credential URL
/// and caches them.
#[derive(Debug)]
pub struct TaskCredentialProvider {
    /// URL the credentials are fetched from.
    pub url: String,
    /// Retry configuration used for the fetch.
    pub retry: RetryConfig,
    /// HTTP client used for the fetch.
    pub client: Client,
    /// Cache holding the most recently fetched credentials.
    pub cache: TokenCache<Arc<AwsCredential>>,
}
642
643#[async_trait]
644impl CredentialProvider for TaskCredentialProvider {
645 type Credential = AwsCredential;
646
647 async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
648 self.cache
649 .get_or_insert_with(|| task_credential(&self.client, &self.retry, &self.url))
650 .await
651 .map_err(|source| crate::Error::Generic {
652 store: STORE,
653 source,
654 })
655 }
656}
657
658async fn task_credential(
660 client: &Client,
661 retry: &RetryConfig,
662 url: &str,
663) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
664 let creds: InstanceCredentials = client.get(url).send_retry(retry).await?.json().await?;
665
666 let now = Utc::now();
667 let ttl = (creds.expiration - now).to_std().unwrap_or_default();
668 Ok(TemporaryToken {
669 token: Arc::new(creds.into()),
670 expiry: Some(Instant::now() + ttl),
671 })
672}
673
/// Token provider that obtains session credentials by issuing a signed
/// `CreateSession` request.
#[derive(Debug)]
pub struct SessionProvider {
    /// Endpoint the `?session` request is sent to.
    pub endpoint: String,
    /// Region used when signing the request.
    pub region: String,
    /// Provider of the credentials used to sign the request.
    pub credentials: AwsCredentialProvider,
}
683
684#[async_trait]
685impl TokenProvider for SessionProvider {
686 type Credential = AwsCredential;
687
688 async fn fetch_token(
689 &self,
690 client: &Client,
691 retry: &RetryConfig,
692 ) -> Result<TemporaryToken<Arc<Self::Credential>>> {
693 let creds = self.credentials.get_credential().await?;
694 let authorizer = AwsAuthorizer::new(&creds, "s3", &self.region);
695
696 let bytes = client
697 .get(format!("{}?session", self.endpoint))
698 .with_aws_sigv4(Some(authorizer), None)
699 .send_retry(retry)
700 .await
701 .context(CreateSessionRequestSnafu)?
702 .bytes()
703 .await
704 .context(CreateSessionResponseSnafu)?;
705
706 let resp: CreateSessionOutput =
707 quick_xml::de::from_reader(bytes.reader()).context(CreateSessionOutputSnafu)?;
708
709 let creds = resp.credentials;
710 Ok(TemporaryToken {
711 token: Arc::new(creds.into()),
712 expiry: Some(Instant::now() + Duration::from_secs(5 * 60)),
714 })
715 }
716}
717
/// Body of the `CreateSession` XML response; field names are deserialized
/// from their PascalCase form.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct CreateSessionOutput {
    /// The `Credentials` element holding the session credentials.
    credentials: SessionCredentials,
}
723
724#[cfg(test)]
725mod tests {
726 use super::*;
727 use crate::client::mock_server::MockServer;
728 use hyper::Response;
729 use reqwest::{Client, Method};
730 use std::env;
731
    /// Signs a body-less GET with payload signing enabled and a fixed date,
    /// and checks the resulting `Authorization` header against a known value.
    #[test]
    fn test_sign_with_signed_payload() {
        let client = Client::new();

        let credential = AwsCredential {
            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            token: None,
        };

        // Fixed signing time so the signature is reproducible.
        let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
            .unwrap()
            .with_timezone(&Utc);

        let mut request = client
            .request(Method::GET, "https://ec2.amazon.com/")
            .build()
            .unwrap();

        let signer = AwsAuthorizer {
            date: Some(date),
            credential: &credential,
            service: "ec2",
            region: "us-east-1",
            sign_payload: true,
            token_header: None,
        };

        signer.authorize(&mut request, None);
        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4")
    }
771
    /// Signs the same request as the signed-payload test but with payload
    /// signing disabled, and checks the resulting `Authorization` header
    /// against a known value.
    #[test]
    fn test_sign_with_unsigned_payload() {
        let client = Client::new();

        let credential = AwsCredential {
            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            token: None,
        };

        // Fixed signing time so the signature is reproducible.
        let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
            .unwrap()
            .with_timezone(&Utc);

        let mut request = client
            .request(Method::GET, "https://ec2.amazon.com/")
            .build()
            .unwrap();

        let authorizer = AwsAuthorizer {
            date: Some(date),
            credential: &credential,
            service: "ec2",
            region: "us-east-1",
            token_header: None,
            sign_payload: false,
        };

        authorizer.authorize(&mut request, None);
        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699");
    }
810
    /// Pre-signs a GET URL for the `s3` service with a fixed date and a
    /// one-day expiry, and checks the full resulting URL (query parameters
    /// and signature) against a known value.
    #[test]
    fn signed_get_url() {
        let credential = AwsCredential {
            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            token: None,
        };

        // Fixed signing time so the signature is reproducible.
        let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);

        let authorizer = AwsAuthorizer {
            date: Some(date),
            credential: &credential,
            service: "s3",
            region: "us-east-1",
            token_header: None,
            sign_payload: false,
        };

        let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap();
        authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400));

        assert_eq!(
            url,
            Url::parse(
                "https://examplebucket.s3.amazonaws.com/test.txt?\
                X-Amz-Algorithm=AWS4-HMAC-SHA256&\
                X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\
                X-Amz-Date=20130524T000000Z&\
                X-Amz-Expires=86400&\
                X-Amz-SignedHeaders=host&\
                X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404"
            )
            .unwrap()
        );
    }
850
    /// Signs a request whose URL has an explicit port and a query string,
    /// and checks the resulting `Authorization` header against a known value.
    #[test]
    fn test_sign_port() {
        let client = Client::new();

        let credential = AwsCredential {
            key_id: "H20ABqCkLZID4rLe".to_string(),
            secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(),
            token: None,
        };

        // Fixed signing time so the signature is reproducible.
        let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z")
            .unwrap()
            .with_timezone(&Utc);

        let mut request = client
            .request(Method::GET, "http://localhost:9000/tsm-schemas")
            .query(&[
                ("delimiter", "/"),
                ("encoding-type", "url"),
                ("list-type", "2"),
                ("prefix", ""),
            ])
            .build()
            .unwrap();

        let authorizer = AwsAuthorizer {
            date: Some(date),
            credential: &credential,
            service: "s3",
            region: "us-east-1",
            token_header: None,
            sign_payload: true,
        };

        authorizer.authorize(&mut request, None);
        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d")
    }
888
    /// Integration test against a real metadata endpoint.
    ///
    /// Skipped unless `TEST_INTEGRATION` is set; requires
    /// `EC2_METADATA_ENDPOINT` to point at the endpoint under test.
    #[tokio::test]
    async fn test_instance_metadata() {
        if env::var("TEST_INTEGRATION").is_err() {
            eprintln!("skipping AWS integration test");
            return;
        }

        let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        // A request without a metadata token must be rejected.
        let resp = client
            .request(Method::GET, format!("{endpoint}/latest/meta-data/ami-id"))
            .send()
            .await
            .unwrap();

        assert_eq!(
            resp.status(),
            StatusCode::UNAUTHORIZED,
            "Ensure metadata endpoint is set to only allow IMDSv2"
        );

        // Fetch credentials with the IMDSv1 fallback disabled.
        let creds = instance_creds(&client, &retry_config, &endpoint, false)
            .await
            .unwrap();

        let id = &creds.token.key_id;
        let secret = &creds.token.secret_key;
        let token = creds.token.token.as_ref().unwrap();

        assert!(!id.is_empty());
        assert!(!secret.is_empty());
        assert!(!token.is_empty())
    }
926
    /// Exercises `instance_creds` against a mock server in three scenarios:
    /// the token request succeeds, the token request is forbidden with the
    /// fallback enabled, and the token request is forbidden with the
    /// fallback disabled.
    #[tokio::test]
    async fn test_mock() {
        let server = MockServer::new().await;

        const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";

        let secret_access_key = "SECRET";
        let access_key_id = "KEYID";
        let token = "TOKEN";

        let endpoint = server.url();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        // Scenario 1: the token request succeeds and the token is forwarded
        // on both follow-up requests.
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/latest/api/token");
            assert_eq!(req.method(), &Method::PUT);
            Response::new("cupcakes".to_string())
        });
        server.push_fn(|req| {
            assert_eq!(
                req.uri().path(),
                "/latest/meta-data/iam/security-credentials/"
            );
            assert_eq!(req.method(), &Method::GET);
            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
            assert_eq!(t, "cupcakes");
            Response::new("myrole".to_string())
        });
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
            assert_eq!(req.method(), &Method::GET);
            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
            assert_eq!(t, "cupcakes");
            Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
        });

        let creds = instance_creds(&client, &retry_config, endpoint, true)
            .await
            .unwrap();

        assert_eq!(creds.token.token.as_deref().unwrap(), token);
        assert_eq!(&creds.token.key_id, access_key_id);
        assert_eq!(&creds.token.secret_key, secret_access_key);

        // Scenario 2: the token request is forbidden and the fallback is
        // enabled, so the follow-up requests carry no token header.
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/latest/api/token");
            assert_eq!(req.method(), &Method::PUT);
            Response::builder()
                .status(StatusCode::FORBIDDEN)
                .body(String::new())
                .unwrap()
        });
        server.push_fn(|req| {
            assert_eq!(
                req.uri().path(),
                "/latest/meta-data/iam/security-credentials/"
            );
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get(IMDSV2_HEADER).is_none());
            Response::new("myrole".to_string())
        });
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get(IMDSV2_HEADER).is_none());
            Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
        });

        let creds = instance_creds(&client, &retry_config, endpoint, true)
            .await
            .unwrap();

        assert_eq!(creds.token.token.as_deref().unwrap(), token);
        assert_eq!(&creds.token.key_id, access_key_id);
        assert_eq!(&creds.token.secret_key, secret_access_key);

        // Scenario 3: the token request is forbidden and the fallback is
        // disabled, so fetching credentials must fail.
        server.push(
            Response::builder()
                .status(StatusCode::FORBIDDEN)
                .body(String::new())
                .unwrap(),
        );

        instance_creds(&client, &retry_config, endpoint, false)
            .await
            .unwrap_err();
    }
1019}