object_store/aws/
credential.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::aws::{AwsCredentialProvider, STORE, STRICT_ENCODE_SET, STRICT_PATH_ENCODE_SET};
19use crate::client::retry::RetryExt;
20use crate::client::token::{TemporaryToken, TokenCache};
21use crate::client::TokenProvider;
22use crate::util::{hex_digest, hex_encode, hmac_sha256};
23use crate::{CredentialProvider, Result, RetryConfig};
24use async_trait::async_trait;
25use bytes::Buf;
26use chrono::{DateTime, Utc};
27use hyper::header::HeaderName;
28use percent_encoding::utf8_percent_encode;
29use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
30use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
31use serde::Deserialize;
32use snafu::{ResultExt, Snafu};
33use std::collections::BTreeMap;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tracing::warn;
37use url::Url;
38
/// Errors raised by the `CreateSession` flow used to obtain session credentials.
///
/// Values of this type are converted into the crate-wide [`crate::Error`] by the
/// `From<Error>` implementation in this module.
#[derive(Debug, Snafu)]
#[allow(clippy::enum_variant_names)]
enum Error {
    // Sending the CreateSession request (including any retries) failed
    #[snafu(display("Error performing CreateSession request: {source}"))]
    CreateSessionRequest { source: crate::client::retry::Error },

    // The request succeeded but the response body could not be read
    #[snafu(display("Error getting CreateSession response: {source}"))]
    CreateSessionResponse { source: reqwest::Error },

    // The response body could not be deserialized from XML
    #[snafu(display("Invalid CreateSessionOutput response: {source}"))]
    CreateSessionOutput { source: quick_xml::DeError },
}
51
52impl From<Error> for crate::Error {
53    fn from(value: Error) -> Self {
54        Self::Generic {
55            store: STORE,
56            source: Box::new(value),
57        }
58    }
59}
60
/// Boxed, thread-safe error type returned by the credential-fetching helpers in this module
type StdError = Box<dyn std::error::Error + Send + Sync>;

/// SHA256 hash of empty string
static EMPTY_SHA256_HASH: &str = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
/// `x-amz-content-sha256` value used when the payload is not signed
static UNSIGNED_PAYLOAD: &str = "UNSIGNED-PAYLOAD";
/// `x-amz-content-sha256` value used when the request body is not available as in-memory bytes
static STREAMING_PAYLOAD: &str = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD";
67
/// A set of AWS security credentials
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}` prints
/// `secret_key` and `token` verbatim — take care not to log it.
#[derive(Debug, Eq, PartialEq)]
pub struct AwsCredential {
    /// AWS_ACCESS_KEY_ID
    pub key_id: String,
    /// AWS_SECRET_ACCESS_KEY
    pub secret_key: String,
    /// AWS_SESSION_TOKEN
    ///
    /// When present it is attached to signed requests as a security token
    pub token: Option<String>,
}
78
79impl AwsCredential {
80    /// Signs a string
81    ///
82    /// <https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html>
83    fn sign(&self, to_sign: &str, date: DateTime<Utc>, region: &str, service: &str) -> String {
84        let date_string = date.format("%Y%m%d").to_string();
85        let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string);
86        let region_hmac = hmac_sha256(date_hmac, region);
87        let service_hmac = hmac_sha256(region_hmac, service);
88        let signing_hmac = hmac_sha256(service_hmac, b"aws4_request");
89        hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref())
90    }
91}
92
/// Authorize a [`Request`] with an [`AwsCredential`] using [AWS SigV4]
///
/// [AWS SigV4]: https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html
#[derive(Debug)]
pub struct AwsAuthorizer<'a> {
    /// Fixed signing time; when `None` the current UTC time is used at signing
    date: Option<DateTime<Utc>>,
    /// Credential used to compute signatures
    credential: &'a AwsCredential,
    /// Service name placed in the credential scope, e.g. `s3`
    service: &'a str,
    /// Region name placed in the credential scope
    region: &'a str,
    /// Overrides the header used for the security token, `x-amz-security-token` if `None`
    token_header: Option<HeaderName>,
    /// Whether `authorize` should sign the request payload
    sign_payload: bool,
}
105
/// Header carrying the request timestamp formatted as `%Y%m%dT%H%M%SZ`
static DATE_HEADER: HeaderName = HeaderName::from_static("x-amz-date");
/// Header carrying the payload digest (or one of the payload marker strings)
static HASH_HEADER: HeaderName = HeaderName::from_static("x-amz-content-sha256");
/// Default header carrying the credential's security token
static TOKEN_HEADER: HeaderName = HeaderName::from_static("x-amz-security-token");
/// Signing algorithm identifier used in the string to sign and the authorization header
const ALGORITHM: &str = "AWS4-HMAC-SHA256";
110
111impl<'a> AwsAuthorizer<'a> {
112    /// Create a new [`AwsAuthorizer`]
113    pub fn new(credential: &'a AwsCredential, service: &'a str, region: &'a str) -> Self {
114        Self {
115            credential,
116            service,
117            region,
118            date: None,
119            sign_payload: true,
120            token_header: None,
121        }
122    }
123
124    /// Controls whether this [`AwsAuthorizer`] will attempt to sign the request payload,
125    /// the default is `true`
126    pub fn with_sign_payload(mut self, signed: bool) -> Self {
127        self.sign_payload = signed;
128        self
129    }
130
131    /// Overrides the header name for security tokens, defaults to `x-amz-security-token`
132    pub(crate) fn with_token_header(mut self, header: HeaderName) -> Self {
133        self.token_header = Some(header);
134        self
135    }
136
137    /// Authorize `request` with an optional pre-calculated SHA256 digest by attaching
138    /// the relevant [AWS SigV4] headers
139    ///
140    /// # Payload Signature
141    ///
142    /// AWS SigV4 requests must contain the `x-amz-content-sha256` header, it is set as follows:
143    ///
144    /// * If not configured to sign payloads, it is set to `UNSIGNED-PAYLOAD`
145    /// * If a `pre_calculated_digest` is provided, it is set to the hex encoding of it
146    /// * If it is a streaming request, it is set to `STREAMING-AWS4-HMAC-SHA256-PAYLOAD`
147    /// * Otherwise it is set to the hex encoded SHA256 of the request body
148    ///
149    /// [AWS SigV4]: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html
150    pub fn authorize(&self, request: &mut Request, pre_calculated_digest: Option<&[u8]>) {
151        if let Some(ref token) = self.credential.token {
152            let token_val = HeaderValue::from_str(token).unwrap();
153            let header = self.token_header.as_ref().unwrap_or(&TOKEN_HEADER);
154            request.headers_mut().insert(header, token_val);
155        }
156
157        let host = &request.url()[url::Position::BeforeHost..url::Position::AfterPort];
158        let host_val = HeaderValue::from_str(host).unwrap();
159        request.headers_mut().insert("host", host_val);
160
161        let date = self.date.unwrap_or_else(Utc::now);
162        let date_str = date.format("%Y%m%dT%H%M%SZ").to_string();
163        let date_val = HeaderValue::from_str(&date_str).unwrap();
164        request.headers_mut().insert(&DATE_HEADER, date_val);
165
166        let digest = match self.sign_payload {
167            false => UNSIGNED_PAYLOAD.to_string(),
168            true => match pre_calculated_digest {
169                Some(digest) => hex_encode(digest),
170                None => match request.body() {
171                    None => EMPTY_SHA256_HASH.to_string(),
172                    Some(body) => match body.as_bytes() {
173                        Some(bytes) => hex_digest(bytes),
174                        None => STREAMING_PAYLOAD.to_string(),
175                    },
176                },
177            },
178        };
179
180        let header_digest = HeaderValue::from_str(&digest).unwrap();
181        request.headers_mut().insert(&HASH_HEADER, header_digest);
182
183        let (signed_headers, canonical_headers) = canonicalize_headers(request.headers());
184
185        let scope = self.scope(date);
186
187        let string_to_sign = self.string_to_sign(
188            date,
189            &scope,
190            request.method(),
191            request.url(),
192            &canonical_headers,
193            &signed_headers,
194            &digest,
195        );
196
197        // sign the string
198        let signature = self
199            .credential
200            .sign(&string_to_sign, date, self.region, self.service);
201
202        // build the actual auth header
203        let authorisation = format!(
204            "{} Credential={}/{}, SignedHeaders={}, Signature={}",
205            ALGORITHM, self.credential.key_id, scope, signed_headers, signature
206        );
207
208        let authorization_val = HeaderValue::from_str(&authorisation).unwrap();
209        request
210            .headers_mut()
211            .insert(&AUTHORIZATION, authorization_val);
212    }
213
214    pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) {
215        let date = self.date.unwrap_or_else(Utc::now);
216        let scope = self.scope(date);
217
218        // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
219        url.query_pairs_mut()
220            .append_pair("X-Amz-Algorithm", ALGORITHM)
221            .append_pair(
222                "X-Amz-Credential",
223                &format!("{}/{}", self.credential.key_id, scope),
224            )
225            .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string())
226            .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string())
227            .append_pair("X-Amz-SignedHeaders", "host");
228
229        // For S3, you must include the X-Amz-Security-Token query parameter in the URL if
230        // using credentials sourced from the STS service.
231        if let Some(ref token) = self.credential.token {
232            url.query_pairs_mut()
233                .append_pair("X-Amz-Security-Token", token);
234        }
235
236        // We don't have a payload; the user is going to send the payload directly themselves.
237        let digest = UNSIGNED_PAYLOAD;
238
239        let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string();
240        let mut headers = HeaderMap::new();
241        let host_val = HeaderValue::from_str(host).unwrap();
242        headers.insert("host", host_val);
243
244        let (signed_headers, canonical_headers) = canonicalize_headers(&headers);
245
246        let string_to_sign = self.string_to_sign(
247            date,
248            &scope,
249            &method,
250            url,
251            &canonical_headers,
252            &signed_headers,
253            digest,
254        );
255
256        let signature = self
257            .credential
258            .sign(&string_to_sign, date, self.region, self.service);
259
260        url.query_pairs_mut()
261            .append_pair("X-Amz-Signature", &signature);
262    }
263
264    #[allow(clippy::too_many_arguments)]
265    fn string_to_sign(
266        &self,
267        date: DateTime<Utc>,
268        scope: &str,
269        request_method: &Method,
270        url: &Url,
271        canonical_headers: &str,
272        signed_headers: &str,
273        digest: &str,
274    ) -> String {
275        // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets
276        // URI-encoded once).
277        // see https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
278        let canonical_uri = match self.service {
279            "s3" => url.path().to_string(),
280            _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(),
281        };
282
283        let canonical_query = canonicalize_query(url);
284
285        // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
286        let canonical_request = format!(
287            "{}\n{}\n{}\n{}\n{}\n{}",
288            request_method.as_str(),
289            canonical_uri,
290            canonical_query,
291            canonical_headers,
292            signed_headers,
293            digest
294        );
295
296        let hashed_canonical_request = hex_digest(canonical_request.as_bytes());
297
298        format!(
299            "{}\n{}\n{}\n{}",
300            ALGORITHM,
301            date.format("%Y%m%dT%H%M%SZ"),
302            scope,
303            hashed_canonical_request
304        )
305    }
306
307    fn scope(&self, date: DateTime<Utc>) -> String {
308        format!(
309            "{}/{}/{}/aws4_request",
310            date.format("%Y%m%d"),
311            self.region,
312            self.service
313        )
314    }
315}
316
/// Extension trait for attaching AWS SigV4 authorization to a request under construction
pub trait CredentialExt {
    /// Sign a request <https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html>
    ///
    /// `authorizer` is the signer to apply and `payload_sha256` an optional
    /// pre-calculated SHA256 digest of the request payload.
    fn with_aws_sigv4(
        self,
        authorizer: Option<AwsAuthorizer<'_>>,
        payload_sha256: Option<&[u8]>,
    ) -> Self;
}
325
326impl CredentialExt for RequestBuilder {
327    fn with_aws_sigv4(
328        self,
329        authorizer: Option<AwsAuthorizer<'_>>,
330        payload_sha256: Option<&[u8]>,
331    ) -> Self {
332        match authorizer {
333            Some(authorizer) => {
334                let (client, request) = self.build_split();
335                let mut request = request.expect("request valid");
336                authorizer.authorize(&mut request, payload_sha256);
337
338                Self::from_parts(client, request)
339            }
340            None => self,
341        }
342    }
343}
344
345/// Canonicalizes query parameters into the AWS canonical form
346///
347/// <https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html>
348fn canonicalize_query(url: &Url) -> String {
349    use std::fmt::Write;
350
351    let capacity = match url.query() {
352        Some(q) if !q.is_empty() => q.len(),
353        _ => return String::new(),
354    };
355    let mut encoded = String::with_capacity(capacity + 1);
356
357    let mut headers = url.query_pairs().collect::<Vec<_>>();
358    headers.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
359
360    let mut first = true;
361    for (k, v) in headers {
362        if !first {
363            encoded.push('&');
364        }
365        first = false;
366        let _ = write!(
367            encoded,
368            "{}={}",
369            utf8_percent_encode(k.as_ref(), &STRICT_ENCODE_SET),
370            utf8_percent_encode(v.as_ref(), &STRICT_ENCODE_SET)
371        );
372    }
373    encoded
374}
375
376/// Canonicalizes headers into the AWS Canonical Form.
377///
378/// <https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html>
379fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) {
380    let mut headers = BTreeMap::<&str, Vec<&str>>::new();
381    let mut value_count = 0;
382    let mut value_bytes = 0;
383    let mut key_bytes = 0;
384
385    for (key, value) in header_map {
386        let key = key.as_str();
387        if ["authorization", "content-length", "user-agent"].contains(&key) {
388            continue;
389        }
390
391        let value = std::str::from_utf8(value.as_bytes()).unwrap();
392        key_bytes += key.len();
393        value_bytes += value.len();
394        value_count += 1;
395        headers.entry(key).or_default().push(value);
396    }
397
398    let mut signed_headers = String::with_capacity(key_bytes + headers.len());
399    let mut canonical_headers =
400        String::with_capacity(key_bytes + value_bytes + headers.len() + value_count);
401
402    for (header_idx, (name, values)) in headers.into_iter().enumerate() {
403        if header_idx != 0 {
404            signed_headers.push(';');
405        }
406
407        signed_headers.push_str(name);
408        canonical_headers.push_str(name);
409        canonical_headers.push(':');
410        for (value_idx, value) in values.into_iter().enumerate() {
411            if value_idx != 0 {
412                canonical_headers.push(',');
413            }
414            canonical_headers.push_str(value.trim());
415        }
416        canonical_headers.push('\n');
417    }
418
419    (signed_headers, canonical_headers)
420}
421
/// Credentials sourced from the instance metadata service
///
/// <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html>
#[derive(Debug)]
pub struct InstanceCredentialProvider {
    /// Whether to continue without a session token (IMDSv1) if the token request
    /// is rejected with a 403
    pub imdsv1_fallback: bool,
    /// Base URL of the metadata service, paths such as `/latest/api/token` are appended to it
    pub metadata_endpoint: String,
}
430
431#[async_trait]
432impl TokenProvider for InstanceCredentialProvider {
433    type Credential = AwsCredential;
434
435    async fn fetch_token(
436        &self,
437        client: &Client,
438        retry: &RetryConfig,
439    ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
440        instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback)
441            .await
442            .map_err(|source| crate::Error::Generic {
443                store: STORE,
444                source,
445            })
446    }
447}
448
/// Credentials sourced using AssumeRoleWithWebIdentity
///
/// <https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-technical-overview.html>
#[derive(Debug)]
pub struct WebIdentityProvider {
    /// Path of the file containing the web identity token, read on every fetch
    pub token_path: String,
    /// Sent as the `RoleArn` request parameter
    pub role_arn: String,
    /// Sent as the `RoleSessionName` request parameter
    pub session_name: String,
    /// URL the AssumeRoleWithWebIdentity request is sent to
    pub endpoint: String,
}
459
460#[async_trait]
461impl TokenProvider for WebIdentityProvider {
462    type Credential = AwsCredential;
463
464    async fn fetch_token(
465        &self,
466        client: &Client,
467        retry: &RetryConfig,
468    ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
469        web_identity(
470            client,
471            retry,
472            &self.token_path,
473            &self.role_arn,
474            &self.session_name,
475            &self.endpoint,
476        )
477        .await
478        .map_err(|source| crate::Error::Generic {
479            store: STORE,
480            source,
481        })
482    }
483}
484
/// Credentials payload deserialized from the JSON responses of the instance
/// metadata and task credential endpoints, field names are `PascalCase` on the wire
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct InstanceCredentials {
    /// Maps to [`AwsCredential::key_id`]
    access_key_id: String,
    /// Maps to [`AwsCredential::secret_key`]
    secret_access_key: String,
    /// Maps to [`AwsCredential::token`]
    token: String,
    /// Time at which the credentials expire, used to compute the cache TTL
    expiration: DateTime<Utc>,
}
493
494impl From<InstanceCredentials> for AwsCredential {
495    fn from(s: InstanceCredentials) -> Self {
496        Self {
497            key_id: s.access_key_id,
498            secret_key: s.secret_access_key,
499            token: Some(s.token),
500        }
501    }
502}
503
504/// <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html#instance-metadata-security-credentials>
505async fn instance_creds(
506    client: &Client,
507    retry_config: &RetryConfig,
508    endpoint: &str,
509    imdsv1_fallback: bool,
510) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
511    const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
512    const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
513
514    let token_url = format!("{endpoint}/latest/api/token");
515
516    let token_result = client
517        .request(Method::PUT, token_url)
518        .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
519        .retryable(retry_config)
520        .idempotent(true)
521        .send()
522        .await;
523
524    let token = match token_result {
525        Ok(t) => Some(t.text().await?),
526        Err(e) if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) => {
527            warn!("received 403 from metadata endpoint, falling back to IMDSv1");
528            None
529        }
530        Err(e) => return Err(e.into()),
531    };
532
533    let role_url = format!("{endpoint}/{CREDENTIALS_PATH}/");
534    let mut role_request = client.request(Method::GET, role_url);
535
536    if let Some(token) = &token {
537        role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
538    }
539
540    let role = role_request.send_retry(retry_config).await?.text().await?;
541
542    let creds_url = format!("{endpoint}/{CREDENTIALS_PATH}/{role}");
543    let mut creds_request = client.request(Method::GET, creds_url);
544    if let Some(token) = &token {
545        creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
546    }
547
548    let creds: InstanceCredentials = creds_request.send_retry(retry_config).await?.json().await?;
549
550    let now = Utc::now();
551    let ttl = (creds.expiration - now).to_std().unwrap_or_default();
552    Ok(TemporaryToken {
553        token: Arc::new(creds.into()),
554        expiry: Some(Instant::now() + ttl),
555    })
556}
557
/// Top-level element of the AssumeRoleWithWebIdentity XML response
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AssumeRoleResponse {
    /// Deserialized from the `AssumeRoleWithWebIdentityResult` element
    assume_role_with_web_identity_result: AssumeRoleResult,
}
563
/// Result element of the AssumeRoleWithWebIdentity XML response
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AssumeRoleResult {
    /// Deserialized from the `Credentials` element
    credentials: SessionCredentials,
}
569
/// Temporary credentials as found in the AssumeRoleWithWebIdentity and
/// CreateSession XML responses, field names are `PascalCase` on the wire
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct SessionCredentials {
    /// Maps to [`AwsCredential::token`]
    session_token: String,
    /// Maps to [`AwsCredential::secret_key`]
    secret_access_key: String,
    /// Maps to [`AwsCredential::key_id`]
    access_key_id: String,
    /// Time at which the credentials expire
    expiration: DateTime<Utc>,
}
578
579impl From<SessionCredentials> for AwsCredential {
580    fn from(s: SessionCredentials) -> Self {
581        Self {
582            key_id: s.access_key_id,
583            secret_key: s.secret_access_key,
584            token: Some(s.session_token),
585        }
586    }
587}
588
589/// <https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-technical-overview.html>
590async fn web_identity(
591    client: &Client,
592    retry_config: &RetryConfig,
593    token_path: &str,
594    role_arn: &str,
595    session_name: &str,
596    endpoint: &str,
597) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
598    let token = std::fs::read_to_string(token_path)
599        .map_err(|e| format!("Failed to read token file '{token_path}': {e}"))?;
600
601    let bytes = client
602        .request(Method::POST, endpoint)
603        .query(&[
604            ("Action", "AssumeRoleWithWebIdentity"),
605            ("DurationSeconds", "3600"),
606            ("RoleArn", role_arn),
607            ("RoleSessionName", session_name),
608            ("Version", "2011-06-15"),
609            ("WebIdentityToken", &token),
610        ])
611        .retryable(retry_config)
612        .idempotent(true)
613        .sensitive(true)
614        .send()
615        .await?
616        .bytes()
617        .await?;
618
619    let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader())
620        .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {e}"))?;
621
622    let creds = resp.assume_role_with_web_identity_result.credentials;
623    let now = Utc::now();
624    let ttl = (creds.expiration - now).to_std().unwrap_or_default();
625
626    Ok(TemporaryToken {
627        token: Arc::new(creds.into()),
628        expiry: Some(Instant::now() + ttl),
629    })
630}
631
/// Credentials sourced from a task IAM role
///
/// <https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html>
#[derive(Debug)]
pub struct TaskCredentialProvider {
    /// URL the credentials are fetched from with a GET request
    pub url: String,
    /// Retry configuration applied to the credentials request
    pub retry: RetryConfig,
    /// HTTP client used to perform the credentials request
    pub client: Client,
    /// Cache for the fetched credentials, consulted before any request is made
    pub cache: TokenCache<Arc<AwsCredential>>,
}
642
643#[async_trait]
644impl CredentialProvider for TaskCredentialProvider {
645    type Credential = AwsCredential;
646
647    async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
648        self.cache
649            .get_or_insert_with(|| task_credential(&self.client, &self.retry, &self.url))
650            .await
651            .map_err(|source| crate::Error::Generic {
652                store: STORE,
653                source,
654            })
655    }
656}
657
658/// <https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html>
659async fn task_credential(
660    client: &Client,
661    retry: &RetryConfig,
662    url: &str,
663) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
664    let creds: InstanceCredentials = client.get(url).send_retry(retry).await?.json().await?;
665
666    let now = Utc::now();
667    let ttl = (creds.expiration - now).to_std().unwrap_or_default();
668    Ok(TemporaryToken {
669        token: Arc::new(creds.into()),
670        expiry: Some(Instant::now() + ttl),
671    })
672}
673
/// A session provider as used by S3 Express One Zone
///
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateSession.html>
#[derive(Debug)]
pub struct SessionProvider {
    /// URL the session is requested from, `?session` is appended to it
    pub endpoint: String,
    /// Region used when signing the session request
    pub region: String,
    /// Provider of the credentials used to sign the session request
    pub credentials: AwsCredentialProvider,
}
683
684#[async_trait]
685impl TokenProvider for SessionProvider {
686    type Credential = AwsCredential;
687
688    async fn fetch_token(
689        &self,
690        client: &Client,
691        retry: &RetryConfig,
692    ) -> Result<TemporaryToken<Arc<Self::Credential>>> {
693        let creds = self.credentials.get_credential().await?;
694        let authorizer = AwsAuthorizer::new(&creds, "s3", &self.region);
695
696        let bytes = client
697            .get(format!("{}?session", self.endpoint))
698            .with_aws_sigv4(Some(authorizer), None)
699            .send_retry(retry)
700            .await
701            .context(CreateSessionRequestSnafu)?
702            .bytes()
703            .await
704            .context(CreateSessionResponseSnafu)?;
705
706        let resp: CreateSessionOutput =
707            quick_xml::de::from_reader(bytes.reader()).context(CreateSessionOutputSnafu)?;
708
709        let creds = resp.credentials;
710        Ok(TemporaryToken {
711            token: Arc::new(creds.into()),
712            // Credentials last 5 minutes - https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateSession.html
713            expiry: Some(Instant::now() + Duration::from_secs(5 * 60)),
714        })
715    }
716}
717
/// Body of the CreateSession XML response
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct CreateSessionOutput {
    /// Deserialized from the `Credentials` element
    credentials: SessionCredentials,
}
723
724#[cfg(test)]
725mod tests {
726    use super::*;
727    use crate::client::mock_server::MockServer;
728    use hyper::Response;
729    use reqwest::{Client, Method};
730    use std::env;
731
732    // Test generated using https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
733    #[test]
734    fn test_sign_with_signed_payload() {
735        let client = Client::new();
736
737        // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html
738        let credential = AwsCredential {
739            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
740            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
741            token: None,
742        };
743
744        // method = 'GET'
745        // service = 'ec2'
746        // host = 'ec2.amazonaws.com'
747        // region = 'us-east-1'
748        // endpoint = 'https://ec2.amazonaws.com'
749        // request_parameters = ''
750        let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
751            .unwrap()
752            .with_timezone(&Utc);
753
754        let mut request = client
755            .request(Method::GET, "https://ec2.amazon.com/")
756            .build()
757            .unwrap();
758
759        let signer = AwsAuthorizer {
760            date: Some(date),
761            credential: &credential,
762            service: "ec2",
763            region: "us-east-1",
764            sign_payload: true,
765            token_header: None,
766        };
767
768        signer.authorize(&mut request, None);
769        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4")
770    }
771
772    #[test]
773    fn test_sign_with_unsigned_payload() {
774        let client = Client::new();
775
776        // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html
777        let credential = AwsCredential {
778            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
779            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
780            token: None,
781        };
782
783        // method = 'GET'
784        // service = 'ec2'
785        // host = 'ec2.amazonaws.com'
786        // region = 'us-east-1'
787        // endpoint = 'https://ec2.amazonaws.com'
788        // request_parameters = ''
789        let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
790            .unwrap()
791            .with_timezone(&Utc);
792
793        let mut request = client
794            .request(Method::GET, "https://ec2.amazon.com/")
795            .build()
796            .unwrap();
797
798        let authorizer = AwsAuthorizer {
799            date: Some(date),
800            credential: &credential,
801            service: "ec2",
802            region: "us-east-1",
803            token_header: None,
804            sign_payload: false,
805        };
806
807        authorizer.authorize(&mut request, None);
808        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699");
809    }
810
811    #[test]
812    fn signed_get_url() {
813        // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html
814        let credential = AwsCredential {
815            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
816            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
817            token: None,
818        };
819
820        let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z")
821            .unwrap()
822            .with_timezone(&Utc);
823
824        let authorizer = AwsAuthorizer {
825            date: Some(date),
826            credential: &credential,
827            service: "s3",
828            region: "us-east-1",
829            token_header: None,
830            sign_payload: false,
831        };
832
833        let mut url = Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap();
834        authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400));
835
836        assert_eq!(
837            url,
838            Url::parse(
839                "https://examplebucket.s3.amazonaws.com/test.txt?\
840                X-Amz-Algorithm=AWS4-HMAC-SHA256&\
841                X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\
842                X-Amz-Date=20130524T000000Z&\
843                X-Amz-Expires=86400&\
844                X-Amz-SignedHeaders=host&\
845                X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404"
846            )
847            .unwrap()
848        );
849    }
850
851    #[test]
852    fn test_sign_port() {
853        let client = Client::new();
854
855        let credential = AwsCredential {
856            key_id: "H20ABqCkLZID4rLe".to_string(),
857            secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(),
858            token: None,
859        };
860
861        let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z")
862            .unwrap()
863            .with_timezone(&Utc);
864
865        let mut request = client
866            .request(Method::GET, "http://localhost:9000/tsm-schemas")
867            .query(&[
868                ("delimiter", "/"),
869                ("encoding-type", "url"),
870                ("list-type", "2"),
871                ("prefix", ""),
872            ])
873            .build()
874            .unwrap();
875
876        let authorizer = AwsAuthorizer {
877            date: Some(date),
878            credential: &credential,
879            service: "s3",
880            region: "us-east-1",
881            token_header: None,
882            sign_payload: true,
883        };
884
885        authorizer.authorize(&mut request, None);
886        assert_eq!(request.headers().get(&AUTHORIZATION).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d")
887    }
888
889    #[tokio::test]
890    async fn test_instance_metadata() {
891        if env::var("TEST_INTEGRATION").is_err() {
892            eprintln!("skipping AWS integration test");
893            return;
894        }
895
896        // For example https://github.com/aws/amazon-ec2-metadata-mock
897        let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap();
898        let client = Client::new();
899        let retry_config = RetryConfig::default();
900
901        // Verify only allows IMDSv2
902        let resp = client
903            .request(Method::GET, format!("{endpoint}/latest/meta-data/ami-id"))
904            .send()
905            .await
906            .unwrap();
907
908        assert_eq!(
909            resp.status(),
910            StatusCode::UNAUTHORIZED,
911            "Ensure metadata endpoint is set to only allow IMDSv2"
912        );
913
914        let creds = instance_creds(&client, &retry_config, &endpoint, false)
915            .await
916            .unwrap();
917
918        let id = &creds.token.key_id;
919        let secret = &creds.token.secret_key;
920        let token = creds.token.token.as_ref().unwrap();
921
922        assert!(!id.is_empty());
923        assert!(!secret.is_empty());
924        assert!(!token.is_empty())
925    }
926
927    #[tokio::test]
928    async fn test_mock() {
929        let server = MockServer::new().await;
930
931        const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";
932
933        let secret_access_key = "SECRET";
934        let access_key_id = "KEYID";
935        let token = "TOKEN";
936
937        let endpoint = server.url();
938        let client = Client::new();
939        let retry_config = RetryConfig::default();
940
941        // Test IMDSv2
942        server.push_fn(|req| {
943            assert_eq!(req.uri().path(), "/latest/api/token");
944            assert_eq!(req.method(), &Method::PUT);
945            Response::new("cupcakes".to_string())
946        });
947        server.push_fn(|req| {
948            assert_eq!(
949                req.uri().path(),
950                "/latest/meta-data/iam/security-credentials/"
951            );
952            assert_eq!(req.method(), &Method::GET);
953            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
954            assert_eq!(t, "cupcakes");
955            Response::new("myrole".to_string())
956        });
957        server.push_fn(|req| {
958            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
959            assert_eq!(req.method(), &Method::GET);
960            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
961            assert_eq!(t, "cupcakes");
962            Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
963        });
964
965        let creds = instance_creds(&client, &retry_config, endpoint, true)
966            .await
967            .unwrap();
968
969        assert_eq!(creds.token.token.as_deref().unwrap(), token);
970        assert_eq!(&creds.token.key_id, access_key_id);
971        assert_eq!(&creds.token.secret_key, secret_access_key);
972
973        // Test IMDSv1 fallback
974        server.push_fn(|req| {
975            assert_eq!(req.uri().path(), "/latest/api/token");
976            assert_eq!(req.method(), &Method::PUT);
977            Response::builder()
978                .status(StatusCode::FORBIDDEN)
979                .body(String::new())
980                .unwrap()
981        });
982        server.push_fn(|req| {
983            assert_eq!(
984                req.uri().path(),
985                "/latest/meta-data/iam/security-credentials/"
986            );
987            assert_eq!(req.method(), &Method::GET);
988            assert!(req.headers().get(IMDSV2_HEADER).is_none());
989            Response::new("myrole".to_string())
990        });
991        server.push_fn(|req| {
992            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
993            assert_eq!(req.method(), &Method::GET);
994            assert!(req.headers().get(IMDSV2_HEADER).is_none());
995            Response::new(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#.to_string())
996        });
997
998        let creds = instance_creds(&client, &retry_config, endpoint, true)
999            .await
1000            .unwrap();
1001
1002        assert_eq!(creds.token.token.as_deref().unwrap(), token);
1003        assert_eq!(&creds.token.key_id, access_key_id);
1004        assert_eq!(&creds.token.secret_key, secret_access_key);
1005
1006        // Test IMDSv1 fallback disabled
1007        server.push(
1008            Response::builder()
1009                .status(StatusCode::FORBIDDEN)
1010                .body(String::new())
1011                .unwrap(),
1012        );
1013
1014        // Should fail
1015        instance_creds(&client, &retry_config, endpoint, false)
1016            .await
1017            .unwrap_err();
1018    }
1019}