snix_castore/blobservice/memory.rs

1use parking_lot::RwLock;
2use std::io::{self, Cursor, Write};
3use std::task::Poll;
4use std::{collections::HashMap, sync::Arc};
5use tonic::async_trait;
6use tracing::instrument;
7
8use super::{BlobReader, BlobService, BlobWriter};
9use crate::composition::{CompositionContext, ServiceBuilder};
10use crate::{B3Digest, Error};
11
12#[derive(Clone, Default)]
13pub struct MemoryBlobService {
14    instance_name: String,
15    db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
16}
17
18#[async_trait]
19impl BlobService for MemoryBlobService {
20    #[instrument(skip_all, ret, err, fields(blob.digest=%digest, instance_name=%self.instance_name))]
21    async fn has(&self, digest: &B3Digest) -> io::Result<bool> {
22        let db = self.db.read();
23        Ok(db.contains_key(digest))
24    }
25
26    #[instrument(skip_all, err, fields(blob.digest=%digest, instance_name=%self.instance_name))]
27    async fn open_read(&self, digest: &B3Digest) -> io::Result<Option<Box<dyn BlobReader>>> {
28        let db = self.db.read();
29
30        match db.get(digest).map(|x| Cursor::new(x.clone())) {
31            Some(result) => Ok(Some(Box::new(result))),
32            None => Ok(None),
33        }
34    }
35
36    #[instrument(skip_all, fields(instance_name=%self.instance_name))]
37    async fn open_write(&self) -> Box<dyn BlobWriter> {
38        Box::new(MemoryBlobWriter::new(self.db.clone()))
39    }
40}
41
42#[derive(serde::Deserialize, Debug)]
43#[serde(deny_unknown_fields)]
44pub struct MemoryBlobServiceConfig {}
45
46impl TryFrom<url::Url> for MemoryBlobServiceConfig {
47    type Error = Box<dyn std::error::Error + Send + Sync>;
48    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
49        // memory doesn't support host or path in the URL.
50        if url.has_host() || !url.path().is_empty() {
51            return Err(Error::StorageError("invalid url".to_string()).into());
52        }
53        Ok(MemoryBlobServiceConfig {})
54    }
55}
56
57#[async_trait]
58impl ServiceBuilder for MemoryBlobServiceConfig {
59    type Output = dyn BlobService;
60    async fn build<'a>(
61        &'a self,
62        instance_name: &str,
63        _context: &CompositionContext,
64    ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
65        Ok(Arc::new(MemoryBlobService {
66            instance_name: instance_name.to_string(),
67            db: Default::default(),
68        }))
69    }
70}
71
72pub struct MemoryBlobWriter {
73    db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
74
75    /// Contains the buffer Vec and hasher, or None if already closed
76    writers: Option<(Vec<u8>, blake3::Hasher)>,
77
78    /// The digest that has been returned, if we successfully closed.
79    digest: Option<B3Digest>,
80}
81
82impl MemoryBlobWriter {
83    fn new(db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>) -> Self {
84        Self {
85            db,
86            writers: Some((Vec::new(), blake3::Hasher::new())),
87            digest: None,
88        }
89    }
90}
91impl tokio::io::AsyncWrite for MemoryBlobWriter {
92    fn poll_write(
93        mut self: std::pin::Pin<&mut Self>,
94        _cx: &mut std::task::Context<'_>,
95        b: &[u8],
96    ) -> std::task::Poll<Result<usize, io::Error>> {
97        Poll::Ready(match &mut self.writers {
98            None => Err(io::Error::new(
99                io::ErrorKind::NotConnected,
100                "already closed",
101            )),
102            Some((buf, hasher)) => {
103                let bytes_written = buf.write(b)?;
104                hasher.write(&b[..bytes_written])
105            }
106        })
107    }
108
109    fn poll_flush(
110        self: std::pin::Pin<&mut Self>,
111        _cx: &mut std::task::Context<'_>,
112    ) -> std::task::Poll<Result<(), io::Error>> {
113        Poll::Ready(match self.writers {
114            None => Err(io::Error::new(
115                io::ErrorKind::NotConnected,
116                "already closed",
117            )),
118            Some(_) => Ok(()),
119        })
120    }
121
122    fn poll_shutdown(
123        self: std::pin::Pin<&mut Self>,
124        _cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), io::Error>> {
126        // shutdown is "instantaneous", we only write to memory.
127        Poll::Ready(Ok(()))
128    }
129}
130
131#[async_trait]
132impl BlobWriter for MemoryBlobWriter {
133    async fn close(&mut self) -> io::Result<B3Digest> {
134        if self.writers.is_none() {
135            match &self.digest {
136                Some(digest) => Ok(digest.clone()),
137                None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "already closed")),
138            }
139        } else {
140            let (buf, hasher) = self.writers.take().unwrap();
141
142            let digest: B3Digest = hasher.finalize().as_bytes().into();
143
144            // Only insert if the blob doesn't already exist.
145            let mut db = self.db.upgradable_read();
146            if !db.contains_key(&digest) {
147                // open the database for writing.
148                db.with_upgraded(|db| {
149                    // and put buf in there. This will move buf out.
150                    db.insert(digest.clone(), buf);
151                });
152            }
153
154            self.digest = Some(digest.clone());
155
156            Ok(digest)
157        }
158    }
159}