1use std::borrow::Cow;
21use std::collections::HashMap;
22use std::future::Future;
23use std::time::{Duration, Instant};
24
25use chrono::Utc;
26use reqwest::{Response, StatusCode};
27use serde::ser::SerializeMap;
28use serde::{Deserialize, Serialize, Serializer};
29
30use crate::aws::client::S3Client;
31use crate::aws::credential::CredentialExt;
32use crate::aws::{AwsAuthorizer, AwsCredential};
33use crate::client::get::GetClientExt;
34use crate::client::retry::Error as RetryError;
35use crate::client::retry::RetryExt;
36use crate::path::Path;
37use crate::{Error, GetOptions, Result};
38
/// Suffix of the `__type` field of an error response that marks a failed
/// conditional check, i.e. a conflicting lock record
const CONFLICT: &str = "ConditionalCheckFailedException";

/// Store name reported in errors raised by this module
const STORE: &str = "DynamoDB";
43
/// Configuration for guarding conditional operations with a lock record
/// stored in a DynamoDB table.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DynamoCommit {
    /// Name of the DynamoDB table holding the lock records
    table_name: String,
    /// Number of milliseconds a lease is valid for; also used as the timeout
    /// of each DynamoDB request
    timeout: u64,
    /// Multiplier applied to a conflicting lease's timeout before this client
    /// attempts to take the lease over
    max_clock_skew_rate: u32,
    /// Offset added to the current time to produce the `ttl` attribute
    /// written with every lock record
    ttl: Duration,
    /// Delay between successive precondition re-checks while waiting for a
    /// conflicting lease to expire
    test_interval: Duration,
}
125
126impl DynamoCommit {
127 pub fn new(table_name: String) -> Self {
129 Self {
130 table_name,
131 timeout: 20_000,
132 max_clock_skew_rate: 3,
133 ttl: Duration::from_secs(60 * 60),
134 test_interval: Duration::from_millis(100),
135 }
136 }
137
138 pub fn with_timeout(mut self, millis: u64) -> Self {
144 self.timeout = millis;
145 self
146 }
147
148 pub fn with_max_clock_skew_rate(mut self, rate: u32) -> Self {
154 self.max_clock_skew_rate = rate;
155 self
156 }
157
158 pub fn with_ttl(mut self, ttl: Duration) -> Self {
163 self.ttl = ttl;
164 self
165 }
166
167 pub(crate) fn from_str(value: &str) -> Option<Self> {
169 Some(match value.split_once(':') {
170 Some((table_name, timeout)) => {
171 Self::new(table_name.trim().to_string()).with_timeout(timeout.parse().ok()?)
172 }
173 None => Self::new(value.trim().to_string()),
174 })
175 }
176
177 pub(crate) fn table_name(&self) -> &str {
179 &self.table_name
180 }
181
182 pub(crate) async fn copy_if_not_exists(
183 &self,
184 client: &S3Client,
185 from: &Path,
186 to: &Path,
187 ) -> Result<()> {
188 self.conditional_op(client, to, None, || async {
189 client.copy_request(from, to).send().await?;
190 Ok(())
191 })
192 .await
193 }
194
195 #[allow(clippy::future_not_send)] pub(crate) async fn conditional_op<F, Fut, T>(
197 &self,
198 client: &S3Client,
199 to: &Path,
200 etag: Option<&str>,
201 op: F,
202 ) -> Result<T>
203 where
204 F: FnOnce() -> Fut,
205 Fut: Future<Output = Result<T, Error>>,
206 {
207 check_precondition(client, to, etag).await?;
208
209 let mut previous_lease = None;
210
211 loop {
212 let existing = previous_lease.as_ref();
213 match self.try_lock(client, to.as_ref(), etag, existing).await? {
214 TryLockResult::Ok(lease) => {
215 let expiry = lease.acquire + lease.timeout;
216 return match tokio::time::timeout_at(expiry.into(), op()).await {
217 Ok(Ok(v)) => Ok(v),
218 Ok(Err(e)) => Err(e),
219 Err(_) => Err(Error::Generic {
220 store: "DynamoDB",
221 source: format!(
222 "Failed to perform conditional operation in {} milliseconds",
223 self.timeout
224 )
225 .into(),
226 }),
227 };
228 }
229 TryLockResult::Conflict(conflict) => {
230 let mut interval = tokio::time::interval(self.test_interval);
231 let expiry = conflict.timeout * self.max_clock_skew_rate;
232 loop {
233 interval.tick().await;
234 check_precondition(client, to, etag).await?;
235 if conflict.acquire.elapsed() > expiry {
236 previous_lease = Some(conflict);
237 break;
238 }
239 }
240 }
241 }
242 }
243 }
244
245 async fn try_lock(
247 &self,
248 s3: &S3Client,
249 path: &str,
250 etag: Option<&str>,
251 existing: Option<&Lease>,
252 ) -> Result<TryLockResult> {
253 let attributes;
254 let (next_gen, condition_expression, expression_attribute_values) = match existing {
255 None => (0_u64, "attribute_not_exists(#pk)", Map(&[])),
256 Some(existing) => {
257 attributes = [(":g", AttributeValue::Number(existing.generation))];
258 (
259 existing.generation.checked_add(1).unwrap(),
260 "attribute_exists(#pk) AND generation = :g",
261 Map(attributes.as_slice()),
262 )
263 }
264 };
265
266 let ttl = (Utc::now() + self.ttl).timestamp();
267 let items = [
268 ("path", AttributeValue::from(path)),
269 ("etag", AttributeValue::from(etag.unwrap_or("*"))),
270 ("generation", AttributeValue::Number(next_gen)),
271 ("timeout", AttributeValue::Number(self.timeout)),
272 ("ttl", AttributeValue::Number(ttl as _)),
273 ];
274 let names = [("#pk", "path")];
275
276 let req = PutItem {
277 table_name: &self.table_name,
278 condition_expression,
279 expression_attribute_values,
280 expression_attribute_names: Map(&names),
281 item: Map(&items),
282 return_values: None,
283 return_values_on_condition_check_failure: Some(ReturnValues::AllOld),
284 };
285
286 let credential = s3.config.get_credential().await?;
287
288 let acquire = Instant::now();
289 match self
290 .request(s3, credential.as_deref(), "DynamoDB_20120810.PutItem", req)
291 .await
292 {
293 Ok(_) => Ok(TryLockResult::Ok(Lease {
294 acquire,
295 generation: next_gen,
296 timeout: Duration::from_millis(self.timeout),
297 })),
298 Err(e) => match parse_error_response(&e) {
299 Some(e) if e.error.ends_with(CONFLICT) => match extract_lease(&e.item) {
300 Some(lease) => Ok(TryLockResult::Conflict(lease)),
301 None => Err(Error::Generic {
302 store: STORE,
303 source: "Failed to extract lease from conflict ReturnValuesOnConditionCheckFailure response".into()
304 }),
305 },
306 _ => Err(Error::Generic {
307 store: STORE,
308 source: Box::new(e),
309 }),
310 },
311 }
312 }
313
314 async fn request<R: Serialize + Send + Sync>(
315 &self,
316 s3: &S3Client,
317 cred: Option<&AwsCredential>,
318 target: &str,
319 req: R,
320 ) -> Result<Response, RetryError> {
321 let region = &s3.config.region;
322 let authorizer = cred.map(|x| AwsAuthorizer::new(x, "dynamodb", region));
323
324 let builder = match &s3.config.endpoint {
325 Some(e) => s3.client.post(e),
326 None => {
327 let url = format!("https://dynamodb.{region}.amazonaws.com");
328 s3.client.post(url)
329 }
330 };
331
332 builder
333 .timeout(Duration::from_millis(self.timeout))
334 .json(&req)
335 .header("X-Amz-Target", target)
336 .with_aws_sigv4(authorizer, None)
337 .send_retry(&s3.config.retry_config)
338 .await
339 }
340}
341
342#[derive(Debug)]
343enum TryLockResult {
344 Ok(Lease),
346 Conflict(Lease),
348}
349
350async fn check_precondition(client: &S3Client, path: &Path, etag: Option<&str>) -> Result<()> {
352 let options = GetOptions {
353 head: true,
354 ..Default::default()
355 };
356
357 match etag {
358 Some(expected) => match client.get_opts(path, options).await {
359 Ok(r) => match r.meta.e_tag {
360 Some(actual) if expected == actual => Ok(()),
361 actual => Err(Error::Precondition {
362 path: path.to_string(),
363 source: format!("{} does not match {expected}", actual.unwrap_or_default())
364 .into(),
365 }),
366 },
367 Err(Error::NotFound { .. }) => Err(Error::Precondition {
368 path: path.to_string(),
369 source: format!("Object at location {path} not found").into(),
370 }),
371 Err(e) => Err(e),
372 },
373 None => match client.get_opts(path, options).await {
374 Ok(_) => Err(Error::AlreadyExists {
375 path: path.to_string(),
376 source: "Already Exists".to_string().into(),
377 }),
378 Err(Error::NotFound { .. }) => Ok(()),
379 Err(e) => Err(e),
380 },
381 }
382}
383
384fn parse_error_response(e: &RetryError) -> Option<ErrorResponse<'_>> {
386 match e {
387 RetryError::Client {
388 status: StatusCode::BAD_REQUEST,
389 body: Some(b),
390 } => serde_json::from_str(b).ok(),
391 _ => None,
392 }
393}
394
395fn extract_lease(item: &HashMap<&str, AttributeValue<'_>>) -> Option<Lease> {
397 let generation = match item.get("generation") {
398 Some(AttributeValue::Number(generation)) => generation,
399 _ => return None,
400 };
401
402 let timeout = match item.get("timeout") {
403 Some(AttributeValue::Number(timeout)) => *timeout,
404 _ => return None,
405 };
406
407 Some(Lease {
408 acquire: Instant::now(),
409 generation: *generation,
410 timeout: Duration::from_millis(timeout),
411 })
412}
413
/// A lease on a lock record, as seen by this process
#[derive(Debug, Clone)]
struct Lease {
    /// Local instant at which the lease was acquired, or at which a
    /// conflicting lease was first observed
    acquire: Instant,
    /// Generation number stored in the lock record
    generation: u64,
    /// Length of time the lease is valid for
    timeout: Duration,
}
421
422#[derive(Serialize)]
426#[serde(rename_all = "PascalCase")]
427struct PutItem<'a> {
428 table_name: &'a str,
430
431 condition_expression: &'a str,
433
434 expression_attribute_names: Map<'a, &'a str, &'a str>,
436
437 expression_attribute_values: Map<'a, &'a str, AttributeValue<'a>>,
439
440 item: Map<'a, &'a str, AttributeValue<'a>>,
442
443 #[serde(skip_serializing_if = "Option::is_none")]
446 return_values: Option<ReturnValues>,
447
448 #[serde(skip_serializing_if = "Option::is_none")]
451 return_values_on_condition_check_failure: Option<ReturnValues>,
452}
453
454#[derive(Deserialize)]
455struct ErrorResponse<'a> {
456 #[serde(rename = "__type")]
457 error: &'a str,
458
459 #[serde(borrow, default, rename = "Item")]
460 item: HashMap<&'a str, AttributeValue<'a>>,
461}
462
463#[derive(Serialize)]
464#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
465enum ReturnValues {
466 AllOld,
467}
468
469struct Map<'a, K, V>(&'a [(K, V)]);
473
474impl<'a, K: Serialize, V: Serialize> Serialize for Map<'a, K, V> {
475 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
476 where
477 S: Serializer,
478 {
479 if self.0.is_empty() {
480 return serializer.serialize_none();
481 }
482 let mut map = serializer.serialize_map(Some(self.0.len()))?;
483 for (k, v) in self.0 {
484 map.serialize_entry(k, v)?
485 }
486 map.end()
487 }
488}
489
490#[derive(Debug, Serialize, Deserialize)]
494enum AttributeValue<'a> {
495 #[serde(rename = "S")]
496 String(Cow<'a, str>),
497 #[serde(rename = "N", with = "number")]
498 Number(u64),
499}
500
501impl<'a> From<&'a str> for AttributeValue<'a> {
502 fn from(value: &'a str) -> Self {
503 Self::String(Cow::Borrowed(value))
504 }
505}
506
507mod number {
509 use serde::{Deserialize, Deserializer, Serializer};
510
511 pub fn serialize<S: Serializer>(v: &u64, s: S) -> Result<S::Ok, S::Error> {
512 s.serialize_str(&v.to_string())
513 }
514
515 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<u64, D::Error> {
516 let v: &str = Deserialize::deserialize(d)?;
517 v.parse().map_err(serde::de::Error::custom)
518 }
519}
520
521#[cfg(test)]
523pub(crate) use tests::integration_test;
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aws::AmazonS3;
    use crate::ObjectStore;
    use rand::distributions::Alphanumeric;
    use rand::{thread_rng, Rng};

    /// A number attribute round-trips through its string-encoded JSON form
    #[test]
    fn test_attribute_serde() {
        let encoded = serde_json::to_string(&AttributeValue::Number(23)).unwrap();
        assert_eq!(encoded, r#"{"N":"23"}"#);

        let decoded: AttributeValue<'_> = serde_json::from_str(&encoded).unwrap();
        assert!(matches!(decoded, AttributeValue::Number(23)));
    }

    /// Exercises locking against a live store; invoked by the integration
    /// tests of the enclosing module
    pub async fn integration_test(integration: &AmazonS3, d: &DynamoCommit) {
        let client = integration.client.as_ref();

        let src = Path::from("dynamo_path_src");
        integration.put(&src, "asd".into()).await.unwrap();

        let dst = Path::from("dynamo_path");
        // Best effort: the destination may not exist yet
        let _ = integration.delete(&dst).await;

        // Either outcome yields the lease currently stored for `dst`
        let existing = match d.try_lock(client, dst.as_ref(), None, None).await.unwrap() {
            TryLockResult::Conflict(l) | TryLockResult::Ok(l) => l,
        };

        // A second attempt without the existing lease must conflict
        let second = d.try_lock(client, dst.as_ref(), None, None).await;
        assert!(matches!(second, Ok(TryLockResult::Conflict(_))));

        // The copy must nevertheless succeed
        d.copy_if_not_exists(client, &src, &dst).await.unwrap();

        // ... and must have advanced the generation by one to do so
        match d.try_lock(client, dst.as_ref(), None, None).await.unwrap() {
            TryLockResult::Conflict(new) => {
                assert_eq!(new.generation, existing.generation + 1);
            }
            _ => panic!("Should conflict"),
        }

        // A random etag selects a record that has not been written before
        let etag: String = thread_rng()
            .sample_iter(Alphanumeric)
            .take(32)
            .map(char::from)
            .collect();
        let t = Some(etag.as_str());

        let l = match d.try_lock(client, dst.as_ref(), t, None).await.unwrap() {
            TryLockResult::Ok(l) => l,
            _ => panic!("should not conflict"),
        };

        match d.try_lock(client, dst.as_ref(), t, None).await.unwrap() {
            TryLockResult::Conflict(c) => assert_eq!(l.generation, c.generation),
            _ => panic!("should conflict"),
        }

        // Presenting the held lease allows the lock to be re-acquired
        match d.try_lock(client, dst.as_ref(), t, Some(&l)).await.unwrap() {
            TryLockResult::Ok(new) => assert_eq!(new.generation, l.generation + 1),
            _ => panic!("should not conflict"),
        }
    }
}