1use crate::azure::STORE;
19use crate::client::retry::RetryExt;
20use crate::client::token::{TemporaryToken, TokenCache};
21use crate::client::{CredentialProvider, TokenProvider};
22use crate::util::hmac_sha256;
23use crate::RetryConfig;
24use async_trait::async_trait;
25use base64::prelude::BASE64_STANDARD;
26use base64::Engine;
27use chrono::{DateTime, SecondsFormat, Utc};
28use reqwest::header::{
29 HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE,
30 CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH,
31 IF_UNMODIFIED_SINCE, RANGE,
32};
33use reqwest::{Client, Method, Request, RequestBuilder};
34use serde::Deserialize;
35use snafu::{ResultExt, Snafu};
36use std::borrow::Cow;
37use std::collections::HashMap;
38use std::fmt::Debug;
39use std::ops::Deref;
40use std::process::Command;
41use std::str;
42use std::sync::Arc;
43use std::time::{Duration, Instant, SystemTime};
44use url::Url;
45
46use super::client::UserDelegationKey;
47
48static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2023-11-03");
49static VERSION: HeaderName = HeaderName::from_static("x-ms-version");
50pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type");
51pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
52pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
53static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
54pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
55const CONTENT_TYPE_JSON: &str = "application/json";
56const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
57const MSI_API_VERSION: &str = "2019-08-01";
58
59const AZURE_STORAGE_SCOPE: &str = "https://storage.azure.com/.default";
63
64const AZURE_STORAGE_RESOURCE: &str = "https://storage.azure.com";
68
69#[derive(Debug, Snafu)]
70pub enum Error {
71 #[snafu(display("Error performing token request: {}", source))]
72 TokenRequest { source: crate::client::retry::Error },
73
74 #[snafu(display("Error getting token response body: {}", source))]
75 TokenResponseBody { source: reqwest::Error },
76
77 #[snafu(display("Error reading federated token file "))]
78 FederatedTokenFile,
79
80 #[snafu(display("Invalid Access Key: {}", source))]
81 InvalidAccessKey { source: base64::DecodeError },
82
83 #[snafu(display("'az account get-access-token' command failed: {message}"))]
84 AzureCli { message: String },
85
86 #[snafu(display("Failed to parse azure cli response: {source}"))]
87 AzureCliResponse { source: serde_json::Error },
88
89 #[snafu(display("Generating SAS keys with SAS tokens auth is not supported"))]
90 SASforSASNotSupported,
91}
92
93pub type Result<T, E = Error> = std::result::Result<T, E>;
94
95impl From<Error> for crate::Error {
96 fn from(value: Error) -> Self {
97 Self::Generic {
98 store: STORE,
99 source: Box::new(value),
100 }
101 }
102}
103
104#[derive(Debug, Clone, Eq, PartialEq)]
106pub struct AzureAccessKey(Vec<u8>);
107
108impl AzureAccessKey {
109 pub fn try_new(key: &str) -> Result<Self> {
111 let key = BASE64_STANDARD.decode(key).context(InvalidAccessKeySnafu)?;
112 Ok(Self(key))
113 }
114}
115
116#[derive(Debug, Eq, PartialEq)]
118pub enum AzureCredential {
119 AccessKey(AzureAccessKey),
123 SASToken(Vec<(String, String)>),
127 BearerToken(String),
131}
132
/// Well-known Azure Active Directory authority hosts for the public and national clouds.
pub mod authority_hosts {
    /// Authority host for the Azure China cloud.
    pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn";
    /// Authority host for the Azure Germany cloud.
    pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de";
    /// Authority host for the Azure US Government cloud.
    pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us";
    /// Authority host for the Azure public cloud; the default used by the OAuth providers.
    pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com";
}
144
145pub(crate) struct AzureSigner {
146 signing_key: AzureAccessKey,
147 start: DateTime<Utc>,
148 end: DateTime<Utc>,
149 account: String,
150 delegation_key: Option<UserDelegationKey>,
151}
152
153impl AzureSigner {
154 pub fn new(
155 signing_key: AzureAccessKey,
156 account: String,
157 start: DateTime<Utc>,
158 end: DateTime<Utc>,
159 delegation_key: Option<UserDelegationKey>,
160 ) -> Self {
161 Self {
162 signing_key,
163 account,
164 start,
165 end,
166 delegation_key,
167 }
168 }
169
170 pub fn sign(&self, method: &Method, url: &mut Url) -> Result<()> {
171 let (str_to_sign, query_pairs) = match &self.delegation_key {
172 Some(delegation_key) => string_to_sign_user_delegation_sas(
173 url,
174 method,
175 &self.account,
176 &self.start,
177 &self.end,
178 delegation_key,
179 ),
180 None => string_to_sign_service_sas(url, method, &self.account, &self.start, &self.end),
181 };
182 let auth = hmac_sha256(&self.signing_key.0, str_to_sign);
183 url.query_pairs_mut().extend_pairs(query_pairs);
184 url.query_pairs_mut()
185 .append_pair("sig", BASE64_STANDARD.encode(auth).as_str());
186 Ok(())
187 }
188}
189
190fn add_date_and_version_headers(request: &mut Request) {
191 let date = Utc::now();
193 let date_str = date.format(RFC1123_FMT).to_string();
194 let date_val = HeaderValue::from_str(&date_str).unwrap();
196 request.headers_mut().insert(DATE, date_val);
197 request
198 .headers_mut()
199 .insert(&VERSION, AZURE_VERSION.clone());
200}
201
202#[derive(Debug)]
204pub struct AzureAuthorizer<'a> {
205 credential: &'a AzureCredential,
206 account: &'a str,
207}
208
209impl<'a> AzureAuthorizer<'a> {
210 pub fn new(credential: &'a AzureCredential, account: &'a str) -> Self {
212 AzureAuthorizer {
213 credential,
214 account,
215 }
216 }
217
218 pub fn authorize(&self, request: &mut Request) {
220 add_date_and_version_headers(request);
221
222 match self.credential {
223 AzureCredential::AccessKey(key) => {
224 let signature = generate_authorization(
225 request.headers(),
226 request.url(),
227 request.method(),
228 self.account,
229 key,
230 );
231
232 request.headers_mut().append(
235 AUTHORIZATION,
236 HeaderValue::from_str(signature.as_str()).unwrap(),
237 );
238 }
239 AzureCredential::BearerToken(token) => {
240 request.headers_mut().append(
241 AUTHORIZATION,
242 HeaderValue::from_str(format!("Bearer {}", token).as_str()).unwrap(),
243 );
244 }
245 AzureCredential::SASToken(query_pairs) => {
246 request
247 .url_mut()
248 .query_pairs_mut()
249 .extend_pairs(query_pairs);
250 }
251 }
252 }
253}
254
255pub(crate) trait CredentialExt {
256 fn with_azure_authorization(
259 self,
260 credential: &Option<impl Deref<Target = AzureCredential>>,
261 account: &str,
262 ) -> Self;
263}
264
265impl CredentialExt for RequestBuilder {
266 fn with_azure_authorization(
267 self,
268 credential: &Option<impl Deref<Target = AzureCredential>>,
269 account: &str,
270 ) -> Self {
271 let (client, request) = self.build_split();
272 let mut request = request.expect("request valid");
273
274 match credential.as_deref() {
275 Some(credential) => {
276 AzureAuthorizer::new(credential, account).authorize(&mut request);
277 }
278 None => {
279 add_date_and_version_headers(&mut request);
280 }
281 }
282
283 Self::from_parts(client, request)
284 }
285}
286
287fn generate_authorization(
290 h: &HeaderMap,
291 u: &Url,
292 method: &Method,
293 account: &str,
294 key: &AzureAccessKey,
295) -> String {
296 let str_to_sign = string_to_sign(h, u, method, account);
297 let auth = hmac_sha256(&key.0, str_to_sign);
298 format!("SharedKey {}:{}", account, BASE64_STANDARD.encode(auth))
299}
300
301fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str {
302 h.get(key)
303 .map(|s| s.to_str())
304 .transpose()
305 .ok()
306 .flatten()
307 .unwrap_or_default()
308}
309
310fn string_to_sign_sas(
311 u: &Url,
312 method: &Method,
313 account: &str,
314 start: &DateTime<Utc>,
315 end: &DateTime<Utc>,
316) -> (String, String, String, String, String) {
317 let signed_resource = "b".to_string();
319
320 let signed_permissions = match *method {
322 Method::GET => match signed_resource.as_str() {
324 "c" => "rl",
325 "b" => "r",
326 _ => unreachable!(),
327 },
328 Method::PUT => "w",
330 Method::DELETE => "d",
332 _ => "",
334 }
335 .to_string();
336 let signed_start = start.to_rfc3339_opts(SecondsFormat::Secs, true);
337 let signed_expiry = end.to_rfc3339_opts(SecondsFormat::Secs, true);
338 let canonicalized_resource = if u.host_str().unwrap_or_default().contains(account) {
339 format!("/blob/{}{}", account, u.path())
340 } else {
341 format!("/blob{}", u.path())
344 };
345
346 (
347 signed_resource,
348 signed_permissions,
349 signed_start,
350 signed_expiry,
351 canonicalized_resource,
352 )
353}
354
355fn string_to_sign_service_sas(
359 u: &Url,
360 method: &Method,
361 account: &str,
362 start: &DateTime<Utc>,
363 end: &DateTime<Utc>,
364) -> (String, HashMap<&'static str, String>) {
365 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
366 string_to_sign_sas(u, method, account, start, end);
367
368 let string_to_sign = format!(
369 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
370 signed_permissions,
371 signed_start,
372 signed_expiry,
373 canonicalized_resource,
374 "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
387
388 let mut pairs = HashMap::new();
389 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
390 pairs.insert("sp", signed_permissions);
391 pairs.insert("st", signed_start);
392 pairs.insert("se", signed_expiry);
393 pairs.insert("sr", signed_resource);
394
395 (string_to_sign, pairs)
396}
397
398fn string_to_sign_user_delegation_sas(
402 u: &Url,
403 method: &Method,
404 account: &str,
405 start: &DateTime<Utc>,
406 end: &DateTime<Utc>,
407 delegation_key: &UserDelegationKey,
408) -> (String, HashMap<&'static str, String>) {
409 let (signed_resource, signed_permissions, signed_start, signed_expiry, canonicalized_resource) =
410 string_to_sign_sas(u, method, account, start, end);
411
412 let string_to_sign = format!(
413 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}",
414 signed_permissions,
415 signed_start,
416 signed_expiry,
417 canonicalized_resource,
418 delegation_key.signed_oid, delegation_key.signed_tid, delegation_key.signed_start, delegation_key.signed_expiry, delegation_key.signed_service, delegation_key.signed_version, "", "", "", "", "", &AZURE_VERSION.to_str().unwrap(), signed_resource, "", "", "", "", "", "", "", );
439
440 let mut pairs = HashMap::new();
441 pairs.insert("sv", AZURE_VERSION.to_str().unwrap().to_string());
442 pairs.insert("sp", signed_permissions);
443 pairs.insert("st", signed_start);
444 pairs.insert("se", signed_expiry);
445 pairs.insert("sr", signed_resource);
446 pairs.insert("skoid", delegation_key.signed_oid.clone());
447 pairs.insert("sktid", delegation_key.signed_tid.clone());
448 pairs.insert("skt", delegation_key.signed_start.clone());
449 pairs.insert("ske", delegation_key.signed_expiry.clone());
450 pairs.insert("sks", delegation_key.signed_service.clone());
451 pairs.insert("skv", delegation_key.signed_version.clone());
452
453 (string_to_sign, pairs)
454}
455
456fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String {
458 let content_length = h
461 .get(&CONTENT_LENGTH)
462 .map(|s| s.to_str())
463 .transpose()
464 .ok()
465 .flatten()
466 .filter(|&v| v != "0")
467 .unwrap_or_default();
468 format!(
469 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
470 method.as_ref(),
471 add_if_exists(h, &CONTENT_ENCODING),
472 add_if_exists(h, &CONTENT_LANGUAGE),
473 content_length,
474 add_if_exists(h, &CONTENT_MD5),
475 add_if_exists(h, &CONTENT_TYPE),
476 add_if_exists(h, &DATE),
477 add_if_exists(h, &IF_MODIFIED_SINCE),
478 add_if_exists(h, &IF_MATCH),
479 add_if_exists(h, &IF_NONE_MATCH),
480 add_if_exists(h, &IF_UNMODIFIED_SINCE),
481 add_if_exists(h, &RANGE),
482 canonicalize_header(h),
483 canonicalize_resource(account, u)
484 )
485}
486
487fn canonicalize_header(headers: &HeaderMap) -> String {
489 let mut names = headers
490 .iter()
491 .filter(|&(k, _)| (k.as_str().starts_with("x-ms")))
492 .map(|(k, _)| (k.as_str(), headers.get(k).unwrap().to_str().unwrap()))
494 .collect::<Vec<_>>();
495 names.sort_unstable();
496
497 let mut result = String::new();
498 for (name, value) in names {
499 result.push_str(name);
500 result.push(':');
501 result.push_str(value);
502 result.push('\n');
503 }
504 result
505}
506
507fn canonicalize_resource(account: &str, uri: &Url) -> String {
509 let mut can_res: String = String::new();
510 can_res.push('/');
511 can_res.push_str(account);
512 can_res.push_str(uri.path().to_string().as_str());
513 can_res.push('\n');
514
515 let query_pairs = uri.query_pairs();
517 {
518 let mut qps: Vec<String> = Vec::new();
519 for (q, _) in query_pairs {
520 if !(qps.iter().any(|x| x == &*q)) {
521 qps.push(q.into_owned());
522 }
523 }
524
525 qps.sort();
526
527 for qparam in qps {
528 let ret = lexy_sort(query_pairs, &qparam);
530
531 can_res = can_res + &qparam.to_lowercase() + ":";
532
533 for (i, item) in ret.iter().enumerate() {
534 if i > 0 {
535 can_res.push(',');
536 }
537 can_res.push_str(item);
538 }
539
540 can_res.push('\n');
541 }
542 };
543
544 can_res[0..can_res.len() - 1].to_owned()
545}
546
/// Collects the values of every pair whose key equals `query_param`, sorted lexicographically.
fn lexy_sort<'a>(
    pairs: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = pairs
        .filter_map(|(key, value)| (key == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
558
559#[derive(Deserialize, Debug)]
561struct OAuthTokenResponse {
562 access_token: String,
563 expires_in: u64,
564}
565
566#[derive(Debug)]
570pub struct ClientSecretOAuthProvider {
571 token_url: String,
572 client_id: String,
573 client_secret: String,
574}
575
576impl ClientSecretOAuthProvider {
577 pub fn new(
579 client_id: String,
580 client_secret: String,
581 tenant_id: impl AsRef<str>,
582 authority_host: Option<String>,
583 ) -> Self {
584 let authority_host =
585 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
586
587 Self {
588 token_url: format!(
589 "{}/{}/oauth2/v2.0/token",
590 authority_host,
591 tenant_id.as_ref()
592 ),
593 client_id,
594 client_secret,
595 }
596 }
597}
598
599#[async_trait::async_trait]
600impl TokenProvider for ClientSecretOAuthProvider {
601 type Credential = AzureCredential;
602
603 async fn fetch_token(
605 &self,
606 client: &Client,
607 retry: &RetryConfig,
608 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
609 let response: OAuthTokenResponse = client
610 .request(Method::POST, &self.token_url)
611 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
612 .form(&[
613 ("client_id", self.client_id.as_str()),
614 ("client_secret", self.client_secret.as_str()),
615 ("scope", AZURE_STORAGE_SCOPE),
616 ("grant_type", "client_credentials"),
617 ])
618 .retryable(retry)
619 .idempotent(true)
620 .send()
621 .await
622 .context(TokenRequestSnafu)?
623 .json()
624 .await
625 .context(TokenResponseBodySnafu)?;
626
627 Ok(TemporaryToken {
628 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
629 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
630 })
631 }
632}
633
634fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result<Instant, D::Error>
635where
636 D: serde::de::Deserializer<'de>,
637{
638 let v = String::deserialize(deserializer)?;
639 let v = v.parse::<u64>().map_err(serde::de::Error::custom)?;
640 let now = SystemTime::now()
641 .duration_since(SystemTime::UNIX_EPOCH)
642 .map_err(serde::de::Error::custom)?;
643
644 Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs())))
645}
646
647#[derive(Debug, Clone, Deserialize)]
651struct ImdsTokenResponse {
652 pub access_token: String,
653 #[serde(deserialize_with = "expires_on_string")]
654 pub expires_on: Instant,
655}
656
/// Obtains tokens from an Azure managed identity token endpoint (by default the instance metadata
/// service).
#[derive(Debug)]
pub struct ImdsManagedIdentityProvider {
    msi_endpoint: String,
    client_id: Option<String>,
    object_id: Option<String>,
    msi_res_id: Option<String>,
}

impl ImdsManagedIdentityProvider {
    /// Creates a provider. At most one of `client_id`, `object_id` and `msi_res_id` is sent when
    /// fetching a token. `msi_endpoint` defaults to
    /// `http://169.254.169.254/metadata/identity/oauth2/token`.
    pub fn new(
        client_id: Option<String>,
        object_id: Option<String>,
        msi_res_id: Option<String>,
        msi_endpoint: Option<String>,
    ) -> Self {
        Self {
            msi_endpoint: msi_endpoint.unwrap_or_else(|| {
                "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
            }),
            client_id,
            object_id,
            msi_res_id,
        }
    }
}
688
689#[async_trait::async_trait]
690impl TokenProvider for ImdsManagedIdentityProvider {
691 type Credential = AzureCredential;
692
693 async fn fetch_token(
695 &self,
696 client: &Client,
697 retry: &RetryConfig,
698 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
699 let mut query_items = vec![
700 ("api-version", MSI_API_VERSION),
701 ("resource", AZURE_STORAGE_RESOURCE),
702 ];
703
704 let mut identity = None;
705 if let Some(client_id) = &self.client_id {
706 identity = Some(("client_id", client_id));
707 }
708 if let Some(object_id) = &self.object_id {
709 identity = Some(("object_id", object_id));
710 }
711 if let Some(msi_res_id) = &self.msi_res_id {
712 identity = Some(("msi_res_id", msi_res_id));
713 }
714 if let Some((key, value)) = identity {
715 query_items.push((key, value));
716 }
717
718 let mut builder = client
719 .request(Method::GET, &self.msi_endpoint)
720 .header("metadata", "true")
721 .query(&query_items);
722
723 if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
724 builder = builder.header("x-identity-header", val);
725 };
726
727 let response: ImdsTokenResponse = builder
728 .send_retry(retry)
729 .await
730 .context(TokenRequestSnafu)?
731 .json()
732 .await
733 .context(TokenResponseBodySnafu)?;
734
735 Ok(TemporaryToken {
736 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
737 expiry: Some(response.expires_on),
738 })
739 }
740}
741
742#[derive(Debug)]
746pub struct WorkloadIdentityOAuthProvider {
747 token_url: String,
748 client_id: String,
749 federated_token_file: String,
750}
751
752impl WorkloadIdentityOAuthProvider {
753 pub fn new(
755 client_id: impl Into<String>,
756 federated_token_file: impl Into<String>,
757 tenant_id: impl AsRef<str>,
758 authority_host: Option<String>,
759 ) -> Self {
760 let authority_host =
761 authority_host.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
762
763 Self {
764 token_url: format!(
765 "{}/{}/oauth2/v2.0/token",
766 authority_host,
767 tenant_id.as_ref()
768 ),
769 client_id: client_id.into(),
770 federated_token_file: federated_token_file.into(),
771 }
772 }
773}
774
775#[async_trait::async_trait]
776impl TokenProvider for WorkloadIdentityOAuthProvider {
777 type Credential = AzureCredential;
778
779 async fn fetch_token(
781 &self,
782 client: &Client,
783 retry: &RetryConfig,
784 ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
785 let token_str = std::fs::read_to_string(&self.federated_token_file)
786 .map_err(|_| Error::FederatedTokenFile)?;
787
788 let response: OAuthTokenResponse = client
790 .request(Method::POST, &self.token_url)
791 .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
792 .form(&[
793 ("client_id", self.client_id.as_str()),
794 (
795 "client_assertion_type",
796 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
797 ),
798 ("client_assertion", token_str.as_str()),
799 ("scope", AZURE_STORAGE_SCOPE),
800 ("grant_type", "client_credentials"),
801 ])
802 .retryable(retry)
803 .idempotent(true)
804 .send()
805 .await
806 .context(TokenRequestSnafu)?
807 .json()
808 .await
809 .context(TokenResponseBodySnafu)?;
810
811 Ok(TemporaryToken {
812 token: Arc::new(AzureCredential::BearerToken(response.access_token)),
813 expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
814 })
815 }
816}
817
818mod az_cli_date_format {
819 use chrono::{DateTime, TimeZone};
820 use serde::{self, Deserialize, Deserializer};
821
822 pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<chrono::Local>, D::Error>
823 where
824 D: Deserializer<'de>,
825 {
826 let s = String::deserialize(deserializer)?;
827 let date = chrono::NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S.%6f")
829 .map_err(serde::de::Error::custom)?;
830 chrono::Local
831 .from_local_datetime(&date)
832 .single()
833 .ok_or(serde::de::Error::custom(
834 "azure cli returned ambiguous expiry date",
835 ))
836 }
837}
838
839#[derive(Debug, Clone, Deserialize)]
840#[serde(rename_all = "camelCase")]
841struct AzureCliTokenResponse {
842 pub access_token: String,
843 #[serde(with = "az_cli_date_format")]
844 pub expires_on: DateTime<chrono::Local>,
845 pub token_type: String,
846}
847
848#[derive(Default, Debug)]
849pub struct AzureCliCredential {
850 cache: TokenCache<Arc<AzureCredential>>,
851}
852
853impl AzureCliCredential {
854 pub fn new() -> Self {
855 Self::default()
856 }
857
858 async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
860 let program = if cfg!(target_os = "windows") {
863 "cmd"
864 } else {
865 "az"
866 };
867 let mut args = Vec::new();
868 if cfg!(target_os = "windows") {
869 args.push("/C");
870 args.push("az");
871 }
872 args.push("account");
873 args.push("get-access-token");
874 args.push("--output");
875 args.push("json");
876 args.push("--scope");
877 args.push(AZURE_STORAGE_SCOPE);
878
879 match Command::new(program).args(args).output() {
880 Ok(az_output) if az_output.status.success() => {
881 let output = str::from_utf8(&az_output.stdout).map_err(|_| Error::AzureCli {
882 message: "az response is not a valid utf-8 string".to_string(),
883 })?;
884
885 let token_response = serde_json::from_str::<AzureCliTokenResponse>(output)
886 .context(AzureCliResponseSnafu)?;
887 if !token_response.token_type.eq_ignore_ascii_case("bearer") {
888 return Err(Error::AzureCli {
889 message: format!(
890 "got unexpected token type from azure cli: {0}",
891 token_response.token_type
892 ),
893 });
894 }
895 let duration =
896 token_response.expires_on.naive_local() - chrono::Local::now().naive_local();
897 Ok(TemporaryToken {
898 token: Arc::new(AzureCredential::BearerToken(token_response.access_token)),
899 expiry: Some(
900 Instant::now()
901 + duration.to_std().map_err(|_| Error::AzureCli {
902 message: "az returned invalid lifetime".to_string(),
903 })?,
904 ),
905 })
906 }
907 Ok(az_output) => {
908 let message = String::from_utf8_lossy(&az_output.stderr);
909 Err(Error::AzureCli {
910 message: message.into(),
911 })
912 }
913 Err(e) => match e.kind() {
914 std::io::ErrorKind::NotFound => Err(Error::AzureCli {
915 message: "Azure Cli not installed".into(),
916 }),
917 error_kind => Err(Error::AzureCli {
918 message: format!("io error: {error_kind:?}"),
919 }),
920 },
921 }
922 }
923}
924
925#[async_trait]
926impl CredentialProvider for AzureCliCredential {
927 type Credential = AzureCredential;
928
929 async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
930 Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
931 }
932}
933
/// Tests for the Azure token providers and for unauthenticated requests, run against an
/// in-process mock HTTP server.
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use http_body_util::BodyExt;
    use hyper::{Response, StatusCode};
    use reqwest::{Client, Method};
    use tempfile::NamedTempFile;

    use super::*;
    use crate::azure::MicrosoftAzureBuilder;
    use crate::client::mock_server::MockServer;
    use crate::{ObjectStore, Path};

    /// The IMDS provider must request the token path with the client id in the query, send the
    /// `metadata` and `x-identity-header` headers, and return the response's access token.
    #[tokio::test]
    async fn test_managed_identity() {
        let server = MockServer::new().await;

        // Sets a process-wide environment variable; the provider forwards it as
        // `x-identity-header`.
        std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret");

        let endpoint = server.url();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        // The handler must be registered before the request is sent; it asserts on the request
        // and returns a canned IMDS response.
        server.push_fn(|req| {
            assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
            assert!(req.uri().query().unwrap().contains("client_id=client_id"));
            assert_eq!(req.method(), &Method::GET);
            let t = req
                .headers()
                .get("x-identity-header")
                .unwrap()
                .to_str()
                .unwrap();
            assert_eq!(t, "env-secret");
            let t = req.headers().get("metadata").unwrap().to_str().unwrap();
            assert_eq!(t, "true");
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": "3599",
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        let credential = ImdsManagedIdentityProvider::new(
            Some("client_id".into()),
            None,
            None,
            Some(format!("{endpoint}/metadata/identity/oauth2/token")),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// The workload identity provider must POST to `/{tenant}/oauth2/v2.0/token` with the
    /// contents of the federated token file in the form body, and return the access token.
    #[tokio::test]
    async fn test_workload_identity() {
        let server = MockServer::new().await;
        let tokenfile = NamedTempFile::new().unwrap();
        let tenant = "tenant";
        std::fs::write(tokenfile.path(), "federated-token").unwrap();

        let endpoint = server.url();
        let client = Client::new();
        let retry_config = RetryConfig::default();

        server.push_fn(move |req| {
            assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token"));
            assert_eq!(req.method(), &Method::POST);
            // The handler is synchronous, so the request body is collected with `block_on`.
            let body = block_on(async move { req.into_body().collect().await.unwrap().to_bytes() });
            let body = String::from_utf8(body.to_vec()).unwrap();
            assert!(body.contains("federated-token"));
            Response::new(
                r#"
            {
                "access_token": "TOKEN",
                "refresh_token": "",
                "expires_in": 3599,
                "expires_on": "1506484173",
                "not_before": "1506480273",
                "resource": "https://management.azure.com/",
                "token_type": "Bearer"
            }
            "#
                .to_string(),
            )
        });

        // The mock server URL is used as the authority host.
        let credential = WorkloadIdentityOAuthProvider::new(
            "client_id",
            tokenfile.path().to_str().unwrap(),
            tenant,
            Some(endpoint.to_string()),
        );

        let token = credential
            .fetch_token(&client, &retry_config)
            .await
            .unwrap();

        assert_eq!(
            token.token.as_ref(),
            &AzureCredential::BearerToken("TOKEN".into())
        );
    }

    /// With `with_skip_signature(true)` no `Authorization` header may be sent even though a
    /// bearer token is configured, and a 404 must surface as `crate::Error::NotFound`.
    #[tokio::test]
    async fn test_no_credentials() {
        let server = MockServer::new().await;

        let endpoint = server.url();
        let store = MicrosoftAzureBuilder::new()
            .with_account("test")
            .with_container_name("test")
            .with_allow_http(true)
            .with_bearer_token_authorization("token")
            .with_endpoint(endpoint.to_string())
            .with_skip_signature(true)
            .build()
            .unwrap();

        server.push_fn(|req| {
            assert_eq!(req.method(), &Method::GET);
            assert!(req.headers().get("Authorization").is_none());
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("not found".to_string())
                .unwrap()
        });

        let path = Path::from("file.txt");
        match store.get(&path).await {
            Err(crate::Error::NotFound { .. }) => {}
            _ => {
                panic!("unexpected response");
            }
        }
    }
}
1089}