snix_castore/blobservice/
memory.rs1use parking_lot::RwLock;
2use std::io::{self, Cursor, Write};
3use std::task::Poll;
4use std::{collections::HashMap, sync::Arc};
5use tonic::async_trait;
6use tracing::instrument;
7
8use super::{BlobReader, BlobService, BlobWriter};
9use crate::composition::{CompositionContext, ServiceBuilder};
10use crate::{B3Digest, Error};
11
12#[derive(Clone, Default)]
13pub struct MemoryBlobService {
14 instance_name: String,
15 db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
16}
17
18#[async_trait]
19impl BlobService for MemoryBlobService {
20 #[instrument(skip_all, ret, err, fields(blob.digest=%digest, instance_name=%self.instance_name))]
21 async fn has(&self, digest: &B3Digest) -> io::Result<bool> {
22 let db = self.db.read();
23 Ok(db.contains_key(digest))
24 }
25
26 #[instrument(skip_all, err, fields(blob.digest=%digest, instance_name=%self.instance_name))]
27 async fn open_read(&self, digest: &B3Digest) -> io::Result<Option<Box<dyn BlobReader>>> {
28 let db = self.db.read();
29
30 match db.get(digest).map(|x| Cursor::new(x.clone())) {
31 Some(result) => Ok(Some(Box::new(result))),
32 None => Ok(None),
33 }
34 }
35
36 #[instrument(skip_all, fields(instance_name=%self.instance_name))]
37 async fn open_write(&self) -> Box<dyn BlobWriter> {
38 Box::new(MemoryBlobWriter::new(self.db.clone()))
39 }
40}
41
42#[derive(serde::Deserialize, Debug)]
43#[serde(deny_unknown_fields)]
44pub struct MemoryBlobServiceConfig {}
45
46impl TryFrom<url::Url> for MemoryBlobServiceConfig {
47 type Error = Box<dyn std::error::Error + Send + Sync>;
48 fn try_from(url: url::Url) -> Result<Self, Self::Error> {
49 if url.has_host() || !url.path().is_empty() {
51 return Err(Error::StorageError("invalid url".to_string()).into());
52 }
53 Ok(MemoryBlobServiceConfig {})
54 }
55}
56
57#[async_trait]
58impl ServiceBuilder for MemoryBlobServiceConfig {
59 type Output = dyn BlobService;
60 async fn build<'a>(
61 &'a self,
62 instance_name: &str,
63 _context: &CompositionContext,
64 ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
65 Ok(Arc::new(MemoryBlobService {
66 instance_name: instance_name.to_string(),
67 db: Default::default(),
68 }))
69 }
70}
71
72pub struct MemoryBlobWriter {
73 db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
74
75 writers: Option<(Vec<u8>, blake3::Hasher)>,
77
78 digest: Option<B3Digest>,
80}
81
82impl MemoryBlobWriter {
83 fn new(db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>) -> Self {
84 Self {
85 db,
86 writers: Some((Vec::new(), blake3::Hasher::new())),
87 digest: None,
88 }
89 }
90}
91impl tokio::io::AsyncWrite for MemoryBlobWriter {
92 fn poll_write(
93 mut self: std::pin::Pin<&mut Self>,
94 _cx: &mut std::task::Context<'_>,
95 b: &[u8],
96 ) -> std::task::Poll<Result<usize, io::Error>> {
97 Poll::Ready(match &mut self.writers {
98 None => Err(io::Error::new(
99 io::ErrorKind::NotConnected,
100 "already closed",
101 )),
102 Some((buf, hasher)) => {
103 let bytes_written = buf.write(b)?;
104 hasher.write(&b[..bytes_written])
105 }
106 })
107 }
108
109 fn poll_flush(
110 self: std::pin::Pin<&mut Self>,
111 _cx: &mut std::task::Context<'_>,
112 ) -> std::task::Poll<Result<(), io::Error>> {
113 Poll::Ready(match self.writers {
114 None => Err(io::Error::new(
115 io::ErrorKind::NotConnected,
116 "already closed",
117 )),
118 Some(_) => Ok(()),
119 })
120 }
121
122 fn poll_shutdown(
123 self: std::pin::Pin<&mut Self>,
124 _cx: &mut std::task::Context<'_>,
125 ) -> std::task::Poll<Result<(), io::Error>> {
126 Poll::Ready(Ok(()))
128 }
129}
130
131#[async_trait]
132impl BlobWriter for MemoryBlobWriter {
133 async fn close(&mut self) -> io::Result<B3Digest> {
134 if self.writers.is_none() {
135 match &self.digest {
136 Some(digest) => Ok(digest.clone()),
137 None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "already closed")),
138 }
139 } else {
140 let (buf, hasher) = self.writers.take().unwrap();
141
142 let digest: B3Digest = hasher.finalize().as_bytes().into();
143
144 let mut db = self.db.upgradable_read();
146 if !db.contains_key(&digest) {
147 db.with_upgraded(|db| {
149 db.insert(digest.clone(), buf);
151 });
152 }
153
154 self.digest = Some(digest.clone());
155
156 Ok(digest)
157 }
158 }
159}