object_store/azure/
credential.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::azure::STORE;
19use crate::client::retry::RetryExt;
20use crate::client::token::{TemporaryToken, TokenCache};
21use crate::client::{CredentialProvider, TokenProvider};
22use crate::util::hmac_sha256;
23use crate::RetryConfig;
24use async_trait::async_trait;
25use base64::prelude::BASE64_STANDARD;
26use base64::Engine;
27use chrono::{DateTime, SecondsFormat, Utc};
28use reqwest::header::{
29    HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE,
30    CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
31    IF_UNMODIFIED_SINCE, RANGE,
32};
33use reqwest::{Client, Method, Request, RequestBuilder};
34use serde::Deserialize;
35use snafu::{ResultExt, Snafu};
36use std::borrow::Cow;
37use std::collections::HashMap;
38use std::fmt::Debug;
39use std::ops::Deref;
40use std::process::Command;
41use std::str;
42use std::sync::Arc;
43use std::time::{Duration, Instant, SystemTime};
44use url::Url;
45
46use super::client::UserDelegationKey;
47
48static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
49static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
50pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
51pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
52pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
53static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
54pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
55const CONTENT_TYPE_JSON: &str = "application/json";
56const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
57const MSI_API_VERSION: &str = "2019-08-01";
58
59/// OIDC scope used when interacting with OAuth2 APIs
60///
61/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/scopes-oidc#the-default-scope>
62const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";
63
64/// Resource ID used when obtaining an access token from the metadata endpoint
65///
66/// <https://learn.microsoft.com/en-us/azure/storage/blobs/authorize-access-azure-active-directory#microsoft-authentication-library-msal>
67const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
68
69#[derive(Debug, Snafu)]
70pub enum Error {
71    #[snafu(display("Error performing token request: {}", source))]
72    TokenRequest { source: crate::client::retry::Error },
73
74    #[snafu(display("Error getting token response body: {}", source))]
75    TokenResponseBody { source: reqwest::Error },
76
77    #[snafu(display("Error reading federated token file "))]
78    FederatedTokenFile,
79
80    #[snafu(display("Invalid Access Key: {}", source))]
81    InvalidAccessKey { source: base64::DecodeError },
82
83    #[snafu(display("'az account get-access-token' command failed: {message}"))]
84    AzureCli { message: String },
85
86    #[snafu(display("Failed to parse azure cli response: {source}"))]
87    AzureCliResponse { source: serde_json::Error },
88
89    #[snafu(display("Generating SAS keys with SAS tokens auth is not supported"))]
90    SASforSASNotSupported,
91}
92
93pub type Result<T, E = Error> = std::result::Result<T, E>;
94
95impl From<Error> for crate::Error {
96    fn from(value: Error) -> Self {
97        Self::Generic {
98            store: STORE,
99            source: Box::new(value),
100        }
101    }
102}
103
104/// A shared Azure Storage Account Key
105#[derive(Debug, Clone, Eq, PartialEq)]
106pub struct AzureAccessKey(Vec<u8>);
107
108impl AzureAccessKey {
109    /// Create a new [`AzureAccessKey`], checking it for validity
110    pub fn try_new(key: &str) -> Result<Self> {
111        let key = BASE64_STANDARD.decode(key).context(InvalidAccessKeySnafu)?;
112        Ok(Self(key))
113    }
114}
115
116/// An Azure storage credential
117#[derive(Debug, Eq, PartialEq)]
118pub enum AzureCredential {
119    /// A shared access key
120    ///
121    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
122    AccessKey(AzureAccessKey),
123    /// A shared access signature
124    ///
125    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature>
126    SASToken(Vec<(String, String)>),
127    /// An authorization token
128    ///
129    /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-azure-active-directory>
130    BearerToken(String),
131}
132
/// Well-known Azure authority hosts that can be supplied as the
/// `authority_host` of the OAuth-based credential providers.
pub mod authority_hosts {
    /// Public Cloud Azure Authority Host
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
    /// US Government Azure Authority Host
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
    /// China-based Azure Authority Host
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Germany-based Azure Authority Host
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
}
144
145pub(crate) struct AzureSigner {
146    signing_key: AzureAccessKey,
147    start: DateTime<Utc>,
148    end: DateTime<Utc>,
149    account: String,
150    delegation_key: Option<UserDelegationKey>,
151}
152
153impl AzureSigner {
154    pub fn new(
155        signing_key: AzureAccessKey,
156        account: String,
157        start: DateTime<Utc>,
158        end: DateTime<Utc>,
159        delegation_key: Option<UserDelegationKey>,
160    ) -> Self {
161        Self {
162            signing_key,
163            account,
164            start,
165            end,
166            delegation_key,
167        }
168    }
169
170    pub fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
171        let (str_to_sign, query_pairs) = match &self.delegation_key {
172            Some(delegation_key) => string_to_sign_user_delegation_sas(
173                url,
174                method,
175                &self.account,
176                &self.start,
177                &self.end,
178                delegation_key,
179            ),
180            None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
181        };
182        let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
183        url.query_pairs_mut().extend_pairs(query_pairs);
184        url.query_pairs_mut()
185            .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
186        Ok(())
187    }
188}
189
190fn add_date_and_version_headers(request: &mut Request) {
191    // rfc2822 string should never contain illegal characters
192    let date = Utc::now();
193    let date_str = date.format(RFC1123_FMT).to_string();
194    // we formatted the data string ourselves, so unwrapping should be fine
195    let date_val = HeaderValue::from_str(&date_str).unwrap();
196    request.headers_mut().insert(DATE, date_val);
197    request
198        .headers_mut()
199        .insert(&VERSION, AZURE_VERSION.clone());
200}
201
202/// Authorize a [`Request`] with an [`AzureAuthorizer`]
203#[derive(Debug)]
204pub struct AzureAuthorizer<'a> {
205    credential: &'a AzureCredential,
206    account: &'a str,
207}
208
209impl<'a> AzureAuthorizer<'a> {
210    /// Create a new [`AzureAuthorizer`]
211    pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
212        AzureAuthorizer {
213            credential,
214            account,
215        }
216    }
217
218    /// Authorize `request`
219    pub fn authorize(&self, request: &mut Request) {
220        add_date_and_version_headers(request);
221
222        match self.credential {
223            AzureCredential::AccessKey(key) => {
224                let signature = generate_authorization(
225                    request.headers(),
226                    request.url(),
227                    request.method(),
228                    self.account,
229                    key,
230                );
231
232                // "signature" is a base 64 encoded string so it should never
233                // contain illegal characters
234                request.headers_mut().append(
235                    AUTHORIZATION,
236                    HeaderValue::from_str(signature.as_str()).unwrap(),
237                );
238            }
239            AzureCredential::BearerToken(token) => {
240                request.headers_mut().append(
241                    AUTHORIZATION,
242                    HeaderValue::from_str(format!("Bearer {}", token).as_str()).unwrap(),
243                );
244            }
245            AzureCredential::SASToken(query_pairs) => {
246                request
247                    .url_mut()
248                    .query_pairs_mut()
249                    .extend_pairs(query_pairs);
250            }
251        }
252    }
253}
254
255pub(crate) trait CredentialExt {
256    /// Apply authorization to requests against azure storage accounts
257    /// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-requests-to-azure-storage>
258    fn with_azure_authorization(
259        self,
260        credential: &Option<impl Deref<Target = AzureCredential>>,
261        account: &str,
262    ) -> Self;
263}
264
265impl CredentialExt for RequestBuilder {
266    fn with_azure_authorization(
267        self,
268        credential: &Option<impl Deref<Target = AzureCredential>>,
269        account: &str,
270    ) -> Self {
271        let (client, request) = self.build_split();
272        let mut request = request.expect("request valid");
273
274        match credential.as_deref() {
275            Some(credential) => {
276                AzureAuthorizer::new(credential, account).authorize(&mut request);
277            }
278            None => {
279                add_date_and_version_headers(&mut request);
280            }
281        }
282
283        Self::from_parts(client, request)
284    }
285}
286
287/// Generate signed key for authorization via access keys
288/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
289fn generate_authorization(
290    h: &HeaderMap,
291    u: &Url,
292    method: &Method,
293    account: &str,
294    key: &AzureAccessKey,
295) -> String {
296    let str_to_sign = string_to_sign(h, u, method, account);
297    let auth = hmac_sha256(&key.0, str_to_sign);
298    format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
299}
300
301fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
302    h.get(key)
303        .map(|s| s.to_str())
304        .transpose()
305        .ok()
306        .flatten()
307        .unwrap_or_default()
308}
309
310fn string_to_sign_sas(
311    u: &Url,
312    method: &Method,
313    account: &str,
314    start: &DateTime<Utc>,
315    end: &DateTime<Utc>,
316) -> (String, String, String, String, String) {
317    // NOTE: for now only blob signing is supported.
318    let signed_resource = "b".to_string();
319
320    // https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob
321    let signed_permissions = match *method {
322        // read and list permissions
323        Method::GET => match signed_resource.as_str() {
324            "c" => "rl",
325            "b" => "r",
326            _ => unreachable!(),
327        },
328        // write permissions (also allows crating a new blob in a sub-key)
329        Method::PUT => "w",
330        // delete permissions
331        Method::DELETE => "d",
332        // other methods are not used in any of the current operations
333        _ => "",
334    }
335    .to_string();
336    let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
337    let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
338    let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
339        format!("/blob/{}{}", account, u.path())
340    } else {
341        // NOTE: in case of the emulator, the account name is not part of the host
342        //      but the path starts with the account name
343        format!("/blob{}", u.path())
344    };
345
346    (
347        signed_resource,
348        signed_permissions,
349        signed_start,
350        signed_expiry,
351        canonicalized_resource,
352    )
353}
354
355/// Create a string to be signed for authorization via [service sas].
356///
357/// [service sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#version-2020-12-06-and-later
358fn string_to_sign_service_sas(
359    u: &Url,
360    method: &Method,
361    account: &str,
362    start: &DateTime<Utc>,
363    end: &DateTime<Utc>,
364) -> (String, HashMap<&'static str, String>) {
365    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
366        string_to_sign_sas(u, method, account, start, end);
367
368    let string_to_sign = format!(
369        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
370        signed_permissions,
371        signed_start,
372        signed_expiry,
373        canonicalized_resource,
374        "",                               // signed identifier
375        "",                               // signed ip
376        "",                               // signed protocol
377        &AZURE_VERSION.to_str().unwrap(), // signed version
378        signed_resource,                  // signed resource
379        "",                               // signed snapshot time
380        "",                               // signed encryption scope
381        "",                               // rscc - response header: Cache-Control
382        "",                               // rscd - response header: Content-Disposition
383        "",                               // rsce - response header: Content-Encoding
384        "",                               // rscl - response header: Content-Language
385        "",                               // rsct - response header: Content-Type
386    );
387
388    let mut pairs = HashMap::new();
389    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
390    pairs.insert("sp", signed_permissions);
391    pairs.insert("st", signed_start);
392    pairs.insert("se", signed_expiry);
393    pairs.insert("sr", signed_resource);
394
395    (string_to_sign, pairs)
396}
397
398/// Create a string to be signed for authorization via [user delegation sas].
399///
400/// [user delegation sas]: https://learn.microsoft.com/en-us/rest/api/storageservices/create-user-delegation-sas#version-2020-12-06-and-later
401fn string_to_sign_user_delegation_sas(
402    u: &Url,
403    method: &Method,
404    account: &str,
405    start: &DateTime<Utc>,
406    end: &DateTime<Utc>,
407    delegation_key: &UserDelegationKey,
408) -> (String, HashMap<&'static str, String>) {
409    let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
410        string_to_sign_sas(u, method, account, start, end);
411
412    let string_to_sign = format!(
413        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
414        signed_permissions,
415        signed_start,
416        signed_expiry,
417        canonicalized_resource,
418        delegation_key.signed_oid,        // signed key object id
419        delegation_key.signed_tid,        // signed key tenant id
420        delegation_key.signed_start,      // signed key start
421        delegation_key.signed_expiry,     // signed key expiry
422        delegation_key.signed_service,    // signed key service
423        delegation_key.signed_version,    // signed key version
424        "",                               // signed authorized user object id
425        "",                               // signed unauthorized user object id
426        "",                               // signed correlation id
427        "",                               // signed ip
428        "",                               // signed protocol
429        &AZURE_VERSION.to_str().unwrap(), // signed version
430        signed_resource,                  // signed resource
431        "",                               // signed snapshot time
432        "",                               // signed encryption scope
433        "",                               // rscc - response header: Cache-Control
434        "",                               // rscd - response header: Content-Disposition
435        "",                               // rsce - response header: Content-Encoding
436        "",                               // rscl - response header: Content-Language
437        "",                               // rsct - response header: Content-Type
438    );
439
440    let mut pairs = HashMap::new();
441    pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
442    pairs.insert("sp", signed_permissions);
443    pairs.insert("st", signed_start);
444    pairs.insert("se", signed_expiry);
445    pairs.insert("sr", signed_resource);
446    pairs.insert("skoid", delegation_key.signed_oid.clone());
447    pairs.insert("sktid", delegation_key.signed_tid.clone());
448    pairs.insert("skt", delegation_key.signed_start.clone());
449    pairs.insert("ske", delegation_key.signed_expiry.clone());
450    pairs.insert("sks", delegation_key.signed_service.clone());
451    pairs.insert("skv", delegation_key.signed_version.clone());
452
453    (string_to_sign, pairs)
454}
455
456/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-signature-string>
457fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
458    // content length must only be specified if != 0
459    // this is valid from 2015-02-21
460    let content_length = h
461        .get(&CONTENT_LENGTH)
462        .map(|s| s.to_str())
463        .transpose()
464        .ok()
465        .flatten()
466        .filter(|&v| v != "0")
467        .unwrap_or_default();
468    format!(
469        "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
470        method.as_ref(),
471        add_if_exists(h, &CONTENT_ENCODING),
472        add_if_exists(h, &CONTENT_LANGUAGE),
473        content_length,
474        add_if_exists(h, &CONTENT_MD5),
475        add_if_exists(h, &CONTENT_TYPE),
476        add_if_exists(h, &DATE),
477        add_if_exists(h, &IF_MODIFIED_SINCE),
478        add_if_exists(h, &IF_MATCH),
479        add_if_exists(h, &IF_NONE_MATCH),
480        add_if_exists(h, &IF_UNMODIFIED_SINCE),
481        add_if_exists(h, &RANGE),
482        canonicalize_header(h),
483        canonicalize_resource(account, u)
484    )
485}
486
487/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-headers-string>
488fn canonicalize_header(headers: &HeaderMap) -> String {
489    let mut names = headers
490        .iter()
491        .filter(|&(k, _)| (k.as_str().starts_with("x-ms")))
492        // TODO remove unwraps
493        .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
494        .collect::<Vec<_>>();
495    names.sort_unstable();
496
497    let mut result = String::new();
498    for (name, value) in names {
499        result.push_str(name);
500        result.push(':');
501        result.push_str(value);
502        result.push('\n');
503    }
504    result
505}
506
507/// <https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#constructing-the-canonicalized-resource-string>
508fn canonicalize_resource(account: &str, uri: &Url) -> String {
509    let mut can_res: String = String::new();
510    can_res.push('/');
511    can_res.push_str(account);
512    can_res.push_str(uri.path().to_string().as_str());
513    can_res.push('\n');
514
515    // query parameters
516    let query_pairs = uri.query_pairs();
517    {
518        let mut qps: Vec<String> = Vec::new();
519        for (q, _) in query_pairs {
520            if !(qps.iter().any(|x| x == &*q)) {
521                qps.push(q.into_owned());
522            }
523        }
524
525        qps.sort();
526
527        for qparam in qps {
528            // find correct parameter
529            let ret = lexy_sort(query_pairs, &qparam);
530
531            can_res = can_res + &qparam.to_lowercase() + ":";
532
533            for (i, item) in ret.iter().enumerate() {
534                if i > 0 {
535                    can_res.push(',');
536                }
537                can_res.push_str(item);
538            }
539
540            can_res.push('\n');
541        }
542    };
543
544    can_res[0..can_res.len() - 1].to_owned()
545}
546
/// Collect every value bound to `query_param` (exact, case-sensitive name
/// match) and return them in ascending lexicographic order.
///
/// Used when building the canonicalized resource to produce the value list of
/// one query parameter.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = Vec::new();
    for (name, value) in vec {
        if name == query_param {
            matching.push(value);
        }
    }
    matching.sort_unstable();
    matching
}
558
559/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#successful-response-1>
560#[derive(Deserialize, Debug)]
561struct OAuthTokenResponse {
562    access_token: String,
563    expires_in: u64,
564}
565
566/// Encapsulates the logic to perform an OAuth token challenge
567///
568/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#first-case-access-token-request-with-a-shared-secret>
569#[derive(Debug)]
570pub struct ClientSecretOAuthProvider {
571    token_url: String,
572    client_id: String,
573    client_secret: String,
574}
575
576impl ClientSecretOAuthProvider {
577    /// Create a new [`ClientSecretOAuthProvider`] for an azure backed store
578    pub fn new(
579        client_id: String,
580        client_secret: String,
581        tenant_id: impl AsRef<str>,
582        authority_host: Option<String>,
583    ) -> Self {
584        let authority_host =
585            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
586
587        Self {
588            token_url: format!(
589                "{}/{}/oauth2/v2.0/token",
590                authority_host,
591                tenant_id.as_ref()
592            ),
593            client_id,
594            client_secret,
595        }
596    }
597}
598
599#[async_trait::async_trait]
600impl TokenProvider for ClientSecretOAuthProvider {
601    type Credential = AzureCredential;
602
603    /// Fetch a token
604    async fn fetch_token(
605        &self,
606        client: &Client,
607        retry: &RetryConfig,
608    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
609        let response: OAuthTokenResponse = client
610            .request(Method::POST, &self.token_url)
611            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
612            .form(&[
613                ("client_id", self.client_id.as_str()),
614                ("client_secret", self.client_secret.as_str()),
615                ("scope", AZURE_STORAGE_SCOPE),
616                ("grant_type", "client_credentials"),
617            ])
618            .retryable(retry)
619            .idempotent(true)
620            .send()
621            .await
622            .context(TokenRequestSnafu)?
623            .json()
624            .await
625            .context(TokenResponseBodySnafu)?;
626
627        Ok(TemporaryToken {
628            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
629            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
630        })
631    }
632}
633
634fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
635where
636    D: serde::de::Deserializer<'de>,
637{
638    let v = String::deserialize(deserializer)?;
639    let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
640    let now = SystemTime::now()
641        .duration_since(SystemTime::UNIX_EPOCH)
642        .map_err(serde::de::Error::custom)?;
643
644    Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
645}
646
647/// NOTE: expires_on is a String version of unix epoch time, not an integer.
648/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
649/// <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#connect-to-azure-services-in-app-code>
650#[derive(Debug, Clone, Deserialize)]
651struct ImdsTokenResponse {
652    pub access_token: String,
653    #[serde(deserialize_with = "expires_on_string")]
654    pub expires_on: Instant,
655}
656
/// Attempts authentication using a managed identity that has been assigned to the deployment environment.
///
/// This authentication type works in Azure VMs, App Service and Azure Functions applications, as well as the Azure Cloud Shell
/// <https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
#[derive(Debug)]
pub struct ImdsManagedIdentityProvider {
    msi_endpoint: String,
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Create a new [`ImdsManagedIdentityProvider`] for an azure backed store.
    ///
    /// `msi_endpoint` defaults to the instance metadata service token endpoint
    /// at `169.254.169.254`.
    pub fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        Self {
            msi_endpoint: msi_endpoint.unwrap_or_else(|| {
                "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
            }),
            client_id,
            object_id,
            msi_res_id,
        }
    }
}
688
689#[async_trait::async_trait]
690impl TokenProvider for ImdsManagedIdentityProvider {
691    type Credential = AzureCredential;
692
693    /// Fetch a token
694    async fn fetch_token(
695        &self,
696        client: &Client,
697        retry: &RetryConfig,
698    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
699        let mut query_items = vec![
700            ("api-version", MSI_API_VERSION),
701            ("resource", AZURE_STORAGE_RESOURCE),
702        ];
703
704        let mut identity = None;
705        if let Some(client_id) = &self.client_id {
706            identity = Some(("client_id", client_id));
707        }
708        if let Some(object_id) = &self.object_id {
709            identity = Some(("object_id", object_id));
710        }
711        if let Some(msi_res_id) = &self.msi_res_id {
712            identity = Some(("msi_res_id", msi_res_id));
713        }
714        if let Some((key, value)) = identity {
715            query_items.push((key, value));
716        }
717
718        let mut builder = client
719            .request(Method::GET, &self.msi_endpoint)
720            .header("metadata", "true")
721            .query(&query_items);
722
723        if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
724            builder = builder.header("x-identity-header", val);
725        };
726
727        let response: ImdsTokenResponse = builder
728            .send_retry(retry)
729            .await
730            .context(TokenRequestSnafu)?
731            .json()
732            .await
733            .context(TokenResponseBodySnafu)?;
734
735        Ok(TemporaryToken {
736            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
737            expiry: Some(response.expires_on),
738        })
739    }
740}
741
742/// Credential for using workload identity federation
743///
744/// <https://learn.microsoft.com/en-us/azure/active-directory/develop/workload-identity-federation>
745#[derive(Debug)]
746pub struct WorkloadIdentityOAuthProvider {
747    token_url: String,
748    client_id: String,
749    federated_token_file: String,
750}
751
752impl WorkloadIdentityOAuthProvider {
753    /// Create a new [`WorkloadIdentityOAuthProvider`] for an azure backed store
754    pub fn new(
755        client_id: impl Into<String>,
756        federated_token_file: impl Into<String>,
757        tenant_id: impl AsRef<str>,
758        authority_host: Option<String>,
759    ) -> Self {
760        let authority_host =
761            authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
762
763        Self {
764            token_url: format!(
765                "{}/{}/oauth2/v2.0/token",
766                authority_host,
767                tenant_id.as_ref()
768            ),
769            client_id: client_id.into(),
770            federated_token_file: federated_token_file.into(),
771        }
772    }
773}
774
775#[async_trait::async_trait]
776impl TokenProvider for WorkloadIdentityOAuthProvider {
777    type Credential = AzureCredential;
778
779    /// Fetch a token
780    async fn fetch_token(
781        &self,
782        client: &Client,
783        retry: &RetryConfig,
784    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
785        let token_str = std::fs::read_to_string(&self.federated_token_file)
786            .map_err(|_| Error::FederatedTokenFile)?;
787
788        // https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential
789        let response: OAuthTokenResponse = client
790            .request(Method::POST, &self.token_url)
791            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
792            .form(&[
793                ("client_id", self.client_id.as_str()),
794                (
795                    "client_assertion_type",
796                    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
797                ),
798                ("client_assertion", token_str.as_str()),
799                ("scope", AZURE_STORAGE_SCOPE),
800                ("grant_type", "client_credentials"),
801            ])
802            .retryable(retry)
803            .idempotent(true)
804            .send()
805            .await
806            .context(TokenRequestSnafu)?
807            .json()
808            .await
809            .context(TokenResponseBodySnafu)?;
810
811        Ok(TemporaryToken {
812            token: Arc::new(AzureCredential::BearerToken(response.access_token)),
813            expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
814        })
815    }
816}
817
818mod az_cli_date_format {
819    use chrono::{DateTime, TimeZone};
820    use serde::{self, Deserialize, Deserializer};
821
822    pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
823    where
824        D: Deserializer<'de>,
825    {
826        let s = String::deserialize(deserializer)?;
827        // expiresOn from azure cli uses the local timezone
828        let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
829            .map_err(serde::de::Error::custom)?;
830        chrono::Local
831            .from_local_datetime(&date)
832            .single()
833            .ok_or(serde::de::Error::custom(
834                "azure cli returned ambiguous expiry date",
835            ))
836    }
837}
838
/// JSON payload printed by `az account get-access-token --output json`.
///
/// Field names are mapped from the CLI's camelCase keys (e.g. `accessToken`,
/// `expiresOn`, `tokenType`) via `rename_all`; any other keys are ignored.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AzureCliTokenResponse {
    /// The access token string, wrapped as a bearer credential by the caller.
    pub access_token: String,
    /// Expiry time, parsed from a local-timezone timestamp by `az_cli_date_format`.
    #[serde(with = "az_cli_date_format")]
    pub expires_on: DateTime<chrono::Local>,
    /// Token type reported by the CLI; callers reject anything other than "bearer".
    pub token_type: String,
}
847
/// Credential provider that obtains Azure Storage bearer tokens from the
/// locally installed Azure CLI (`az account get-access-token`).
#[derive(Default, Debug)]
pub struct AzureCliCredential {
    /// Cache of the most recently fetched token and its expiry.
    cache: TokenCache<Arc<AzureCredential>>,
}
852
853impl AzureCliCredential {
854    pub fn new() -> Self {
855        Self::default()
856    }
857
858    /// Fetch a token
859    async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
860        // on window az is a cmd and it should be called like this
861        // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html
862        let program = if cfg!(target_os = "windows") {
863            "cmd"
864        } else {
865            "az"
866        };
867        let mut args = Vec::new();
868        if cfg!(target_os = "windows") {
869            args.push("/C");
870            args.push("az");
871        }
872        args.push("account");
873        args.push("get-access-token");
874        args.push("--output");
875        args.push("json");
876        args.push("--scope");
877        args.push(AZURE_STORAGE_SCOPE);
878
879        match Command::new(program).args(args).output() {
880            Ok(az_output) if az_output.status.success() => {
881                let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
882                    message: "az response is not a valid utf-8 string".to_string(),
883                })?;
884
885                let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
886                    .context(AzureCliResponseSnafu)?;
887                if !token_response.token_type.eq_ignore_ascii_case("bearer") {
888                    return Err(Error::AzureCli {
889                        message: format!(
890                            "got unexpected token type from azure cli: {0}",
891                            token_response.token_type
892                        ),
893                    });
894                }
895                let duration =
896                    token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
897                Ok(TemporaryToken {
898                    token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
899                    expiry: Some(
900                        Instant::now()
901                            + duration.to_std().map_err(|_| Error::AzureCli {
902                                message: "az returned invalid lifetime".to_string(),
903                            })?,
904                    ),
905                })
906            }
907            Ok(az_output) => {
908                let message = String::from_utf8_lossy(&az_output.stderr);
909                Err(Error::AzureCli {
910                    message: message.into(),
911                })
912            }
913            Err(e) => match e.kind() {
914                std::io::ErrorKind::NotFound => Err(Error::AzureCli {
915                    message: "Azure Cli not installed".into(),
916                }),
917                error_kind => Err(Error::AzureCli {
918                    message: format!("io error: {error_kind:?}"),
919                }),
920            },
921        }
922    }
923}
924
925#[async_trait]
926impl CredentialProvider for AzureCliCredential {
927    type Credential = AzureCredential;
928
929    async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
930        Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
931    }
932}
933
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use http_body_util::BodyExt;
    use hyper::{Response, StatusCode};
    use reqwest::{Client, Method};
    use tempfile::NamedTempFile;

    use super::*;
    use crate::azure::MicrosoftAzureBuilder;
    use crate::client::mock_server::MockServer;
    use crate::{ObjectStore, Path};

    /// IMDS managed identity: checks the request shape (path, client_id query,
    /// identity/metadata headers) and that the returned access token is used.
    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret");

        let endpoint = server.url();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        // Test IMDS
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        // The environment is process-global: don't leak the secret into other
        // tests running in the same test binary.
        std::env::remove_var(MSI_SECRET_ENV_KEY);

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// Workload identity: the federated token read from the file must be sent
    /// to the tenant's v2.0 token endpoint and exchanged for an access token.
    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        // Mock the federated-credential token exchange endpoint
        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
              }
            "#
                .to_string(),
            )
        });

        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// With `skip_signature` enabled, requests carry no Authorization header
    /// even though a bearer token was configured.
    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }
}