snix_castore/blobservice/object_store/
aws.rs

1//! Custom AWS logic to configure the AWS [object_store].
2//! Upstream doesn't support the AWS credential chain.
3
4use std::sync::{Arc, RwLock};
5use tonic::async_trait;
6
7/// Wrapper type implementing [object_store::CredentialProvider]
8/// for [aws_credential_types::provider::ProvideCredentials].
9#[derive(Debug)]
10struct AwsConfigCredentialProvider<T: aws_credential_types::provider::ProvideCredentials> {
11    inner: T,
12
13    /// Retrieved credentials, alongside their expiry time, if any.
14    credential: std::sync::RwLock<
15        Option<(
16            Arc<object_store::aws::AwsCredential>,
17            Option<std::time::SystemTime>,
18        )>,
19    >,
20}
21
22impl<T> AwsConfigCredentialProvider<T>
23where
24    T: aws_credential_types::provider::ProvideCredentials,
25{
26    pub fn new(aws_credential_provider: T) -> Self {
27        Self {
28            inner: aws_credential_provider,
29            credential: RwLock::new(None),
30        }
31    }
32
33    async fn get_new_creds(
34        &self,
35    ) -> object_store::Result<(
36        Arc<object_store::aws::AwsCredential>,
37        Option<std::time::SystemTime>,
38    )> {
39        let credentials =
40            self.inner
41                .provide_credentials()
42                .await
43                .map_err(|err| object_store::Error::Generic {
44                    store: "S3",
45                    source: Box::new(err),
46                })?;
47
48        let object_store_credentials = Arc::new(object_store::aws::AwsCredential {
49            key_id: credentials.access_key_id().to_owned(),
50            secret_key: credentials.secret_access_key().to_owned(),
51            token: credentials.session_token().map(|s| s.to_owned()),
52        });
53        let expiry = credentials.expiry();
54
55        Ok((object_store_credentials, expiry))
56    }
57}
58
59#[async_trait]
60impl<T> object_store::CredentialProvider for AwsConfigCredentialProvider<T>
61where
62    T: aws_credential_types::provider::ProvideCredentials,
63{
64    type Credential = object_store::aws::AwsCredential;
65
66    async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
67        if let Some(credential) = self.credential.read().expect("poisoned").as_ref() {
68            match credential.1 {
69                // creds expired, renewal below
70                Some(expiry) if expiry <= std::time::SystemTime::now() => {}
71                // not yet expired, or no expiry
72                Some(_) | None => {
73                    return Ok(credential.0.clone());
74                }
75            }
76        }
77
78        // get new creds
79        let (object_store_credential, expiry) = self.get_new_creds().await?;
80
81        let mut c = self.credential.write().expect("poisoned");
82        *c = Some((object_store_credential.clone(), expiry));
83
84        Ok(object_store_credential)
85    }
86}
87
88/// Returns an AWS object store. Contrary to [object_store::parse_url_opts],
89/// this one honors the AWS credential chain. It does only support s3:// URLs.
90///
91/// For opts, only the following keys are allowed:
92///
93///  - aws_access_key_id
94///  - aws_secret_access_key
95///  - aws_region
96///  - aws_allow_http
97///  - aws_endpoint_url
98///  - aws_profile
99///  - user_agent
100///
101/// These keys will take priority over anything discovered via the AWS default
102/// credential chain.
103pub async fn setup_aws_object_store<'a, KV>(
104    url: &url::Url,
105    opts: KV,
106) -> Result<object_store::aws::AmazonS3, Box<dyn std::error::Error + Send + Sync + 'static>>
107where
108    KV: IntoIterator<Item = (&'a str, &'a str)>,
109{
110    let bucket_name = url
111        .host_str()
112        .ok_or_else(|| Box::new(std::io::Error::other("no bucket name set")))?;
113
114    // The AWS SDK config loader.
115    let mut config_loader = aws_config::from_env();
116
117    let mut aws_access_key_id = None;
118    let mut aws_secret_access_key = None;
119    let mut allow_http = false;
120    let mut user_agent = None;
121
122    for (k, v) in opts.into_iter() {
123        match k {
124            "aws_access_key_id" => {
125                aws_access_key_id = Some(v);
126            }
127            "aws_secret_access_key" => {
128                aws_secret_access_key = Some(v);
129            }
130            "aws_region" => {
131                config_loader = config_loader.region(aws_config::Region::new(v.to_owned()));
132            }
133            "aws_allow_http" => {
134                if v == "1" || v == "true" {
135                    allow_http = true;
136                }
137            }
138            "aws_endpoint_url" => {
139                config_loader = config_loader.endpoint_url(v);
140            }
141            "aws_profile" => {
142                config_loader = config_loader.profile_name(v);
143            }
144            "user_agent" => {
145                user_agent = Some(v);
146            }
147            _ => {
148                return Err(Box::new(std::io::Error::new(
149                    std::io::ErrorKind::InvalidInput,
150                    format!("unexpected param: {}", k),
151                )));
152            }
153        }
154    }
155
156    match (aws_access_key_id, aws_secret_access_key) {
157        (None, None) => {}
158        (None, Some(_)) | (Some(_), None) => {
159            return Err(Box::new(std::io::Error::new(
160                std::io::ErrorKind::InvalidInput,
161                "specified only one of `aws_access_key_id`, `aws_secret_access_key`, need to specify both or none",
162            )));
163        }
164        (Some(aws_access_key_id), Some(aws_secret_access_key)) => {
165            config_loader =
166                config_loader.credentials_provider(aws_credential_types::Credentials::new(
167                    aws_access_key_id,
168                    aws_secret_access_key,
169                    None,
170                    None,
171                    "url-params",
172                ));
173        }
174    }
175
176    // FUTUREWORK: can we split this out to make things more testable?
177    let sdk_config = config_loader.load().await;
178
179    let sdk_credentials_provider = sdk_config.credentials_provider().ok_or_else(|| {
180        Box::new(std::io::Error::new(
181            std::io::ErrorKind::PermissionDenied,
182            "couldn't discover AWS credential provider",
183        ))
184    })?;
185
186    let mut store_builder = object_store::aws::AmazonS3Builder::new()
187        .with_credentials(Arc::new(AwsConfigCredentialProvider::new(
188            sdk_credentials_provider,
189        )))
190        .with_bucket_name(bucket_name)
191        .with_allow_http(allow_http)
192        .with_client_options({
193            let mut client_options = object_store::ClientOptions::new().with_allow_http(allow_http);
194
195            if let Some(user_agent) = user_agent {
196                client_options = client_options
197                    .with_user_agent(object_store::HeaderValue::from_str(user_agent)?);
198            }
199
200            client_options
201        });
202
203    if let Some(region) = sdk_config.region() {
204        store_builder = store_builder.with_region(region.to_string());
205    }
206
207    if let Some(endpoint_url) = sdk_config.endpoint_url() {
208        store_builder = store_builder.with_endpoint(endpoint_url);
209    }
210
211    Ok(store_builder.build()?)
212}