// snix_castore/blobservice/grpc.rs
use super::{BlobReader, BlobService, BlobWriter, ChunkedReader};
2use crate::composition::{CompositionContext, ServiceBuilder};
3use crate::{
4 B3Digest,
5 proto::{self, stat_blob_response::ChunkMeta},
6};
7use futures::sink::SinkExt;
8use std::{
9 io::{self, Cursor},
10 pin::pin,
11 sync::Arc,
12 task::Poll,
13};
14use tokio::io::AsyncWriteExt;
15use tokio::task::JoinHandle;
16use tokio_stream::{StreamExt, wrappers::ReceiverStream};
17use tokio_util::{
18 io::{CopyToBytes, SinkWriter},
19 sync::PollSender,
20};
21use tonic::{Code, Status, async_trait};
22use tracing::{Instrument, Span, instrument};
23
/// A [BlobService] implementation that proxies all operations to a remote
/// gRPC BlobService via the generated tonic client.
#[derive(Clone)]
pub struct GRPCBlobService<T> {
    /// Instance name, only used to tag tracing spans emitted by the methods.
    instance_name: String,
    /// The underlying generated tonic client; cloned per request (cheap).
    grpc_client: proto::blob_service_client::BlobServiceClient<T>,
}
32
33impl<T> GRPCBlobService<T> {
34 pub fn from_client(
36 instance_name: String,
37 grpc_client: proto::blob_service_client::BlobServiceClient<T>,
38 ) -> Self {
39 Self {
40 instance_name,
41 grpc_client,
42 }
43 }
44}
45
#[async_trait]
impl<T> BlobService for GRPCBlobService<T>
where
    T: tonic::client::GrpcService<tonic::body::Body> + Send + Sync + Clone + 'static,
    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
    T::Future: Send,
{
    /// Checks for the existence of a blob by issuing a `Stat` request.
    /// A successful response means the blob exists; `NotFound` maps to
    /// `Ok(false)`, every other status becomes an I/O error.
    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name))]
    async fn has(&self, digest: &B3Digest) -> io::Result<bool> {
        match self
            .grpc_client
            .clone()
            .stat(proto::StatBlobRequest {
                digest: (*digest).into(),
                ..Default::default()
            })
            .await
        {
            Ok(_blob_meta) => Ok(true),
            Err(e) if e.code() == Code::NotFound => Ok(false),
            Err(e) => Err(io::Error::other(e)),
        }
    }

    /// Opens a reader for the blob with the given digest, or returns
    /// `Ok(None)` if the remote side doesn't know the blob.
    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
    async fn open_read(&self, digest: &B3Digest) -> io::Result<Option<Box<dyn BlobReader>>> {
        // First ask the server how (and whether) the blob is chunked.
        match self.chunks(digest).await {
            Ok(None) => Ok(None),
            Ok(Some(chunks)) => {
                if chunks.is_empty() || chunks.len() == 1 {
                    // No chunking info (or a single chunk): fetch the whole
                    // blob with one `Read` call and buffer it in memory, so
                    // the returned reader is seekable via Cursor.
                    return match self
                        .grpc_client
                        .clone()
                        .read(proto::ReadBlobRequest {
                            digest: (*digest).into(),
                        })
                        .await
                    {
                        Ok(stream) => {
                            // Project each BlobChunk to its data bytes;
                            // transport errors become InvalidData IO errors.
                            let data_stream = stream.into_inner().map(|e| {
                                e.map(|c| c.data)
                                    .map_err(|s| std::io::Error::new(io::ErrorKind::InvalidData, s))
                            });

                            let mut data_reader = tokio_util::io::StreamReader::new(data_stream);

                            let mut buf = Vec::new();
                            tokio::io::copy(&mut data_reader, &mut buf).await?;

                            Ok(Some(Box::new(Cursor::new(buf))))
                        }
                        // The blob may have vanished between Stat and Read.
                        Err(e) if e.code() == Code::NotFound => Ok(None),
                        Err(e) => Err(io::Error::other(e)),
                    };
                }

                // Multiple chunks: return a ChunkedReader that fetches each
                // chunk on demand through `self` (cloned into an Arc).
                let chunked_reader = ChunkedReader::from_chunks(
                    chunks.into_iter().map(|chunk| {
                        (
                            chunk.digest.try_into().expect("invalid b3 digest"),
                            chunk.size,
                        )
                    }),
                    Arc::new(self.clone()) as Arc<dyn BlobService>,
                );
                Ok(Some(Box::new(chunked_reader)))
            }
            Err(e) => Err(e)?,
        }
    }

    /// Returns a [BlobWriter] that streams written bytes to the server via
    /// a client-streaming `Put` request running in a spawned task.
    #[instrument(skip_all, fields(instance_name=%self.instance_name))]
    async fn open_write(&self) -> Box<dyn BlobWriter> {
        // Bounded channel carrying the written bytes over to the Put stream.
        let (tx, rx) = tokio::sync::mpsc::channel::<bytes::Bytes>(64);

        let span = Span::current();
        // Wrap every Bytes buffer into a BlobChunk message, emitting a
        // trace span per chunk (parented to the open_write span).
        let blobchunk_stream = ReceiverStream::new(rx).map(move |x| {
            let span = tracing::trace_span!(
                parent: &span,
                "blob_chunk",
                blob.size = x.len()
            );

            span.in_scope(|| {
                tracing::trace!("constructing BlobChunk");
            });

            proto::BlobChunk { data: x }
        });

        // The Put request runs in the background; close() on the returned
        // writer joins this task and surfaces the server's response.
        let task = tokio::spawn({
            let mut grpc_client = self.grpc_client.clone();
            async move { Ok::<_, Status>(grpc_client.put(blobchunk_stream).await?.into_inner()) }
                .in_current_span()
        });

        // Adapt the channel sender into an AsyncWrite: a dropped receiver
        // surfaces as BrokenPipe to the writer side.
        let sink = PollSender::new(tx)
            .sink_map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e));

        let writer = SinkWriter::new(CopyToBytes::new(sink));

        Box::new(GRPCBlobWriter {
            task_and_writer: Some((task, writer)),
            digest: None,
        })
    }

    /// Asks the server for the chunk list of a blob via `Stat` with
    /// `send_chunks = true`. Returns `Ok(None)` if the blob is unknown;
    /// the response is validated before its chunk list is returned.
    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
    async fn chunks(&self, digest: &B3Digest) -> io::Result<Option<Vec<ChunkMeta>>> {
        let resp = self
            .grpc_client
            .clone()
            .stat(proto::StatBlobRequest {
                digest: (*digest).into(),
                send_chunks: true,
                ..Default::default()
            })
            .await;

        match resp {
            Err(e) if e.code() == Code::NotFound => Ok(None),
            Err(e) => Err(io::Error::other(e)),
            Ok(resp) => {
                let resp = resp.into_inner();

                // Reject malformed responses (e.g. invalid chunk digests).
                resp.validate()
                    .map_err(|e| std::io::Error::new(io::ErrorKind::InvalidData, e))?;

                Ok(Some(resp.chunks))
            }
        }
    }
}
203
/// Deserializable configuration for a [GRPCBlobService].
#[derive(serde::Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GRPCBlobServiceConfig {
    // Endpoint URL, handed to `channel_from_url` when the service is built.
    url: String,
}
209
210impl TryFrom<url::Url> for GRPCBlobServiceConfig {
211 type Error = Box<dyn std::error::Error + Send + Sync>;
212 fn try_from(url: url::Url) -> Result<Self, Self::Error> {
213 Ok(GRPCBlobServiceConfig {
218 url: url.to_string(),
219 })
220 }
221}
222
#[async_trait]
impl ServiceBuilder for GRPCBlobServiceConfig {
    type Output = dyn BlobService;
    /// Opens a channel to the configured URL, attaches the trace-propagation
    /// interceptor, and wraps the resulting client in a [GRPCBlobService].
    async fn build<'a>(
        &'a self,
        instance_name: &str,
        _context: &CompositionContext,
    ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = proto::blob_service_client::BlobServiceClient::with_interceptor(
            crate::tonic::channel_from_url(&self.url.parse()?).await?,
            snix_tracing::propagate::tonic::send_trace,
        );
        Ok(Arc::new(GRPCBlobService::from_client(
            instance_name.to_string(),
            client,
        )))
    }
}
241
/// A [BlobWriter] whose written bytes are streamed to a gRPC `Put` request
/// running in a background task.
pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
    // Background put task and the inner writer feeding it. Taken together
    // on close(); `None` means the writer has already been closed.
    task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,

    // Digest returned by the server after a successful close(), cached so
    // repeated close() calls return the same value.
    digest: Option<B3Digest>,
}
249
250#[async_trait]
251impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
252 async fn close(&mut self) -> io::Result<B3Digest> {
253 if self.task_and_writer.is_none() {
254 match &self.digest {
257 Some(digest) => Ok(*digest),
258 None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "already closed")),
259 }
260 } else {
261 let (task, mut writer) = self.task_and_writer.take().unwrap();
262
263 writer.shutdown().await?;
266
267 match task.await? {
272 Ok(resp) => {
273 let digest_len = resp.digest.len();
275 let digest: B3Digest = resp.digest.try_into().map_err(|_| {
276 io::Error::other(format!(
277 "invalid root digest length {digest_len} in response"
278 ))
279 })?;
280 self.digest = Some(digest);
281 Ok(digest)
282 }
283 Err(e) => Err(io::Error::other(e.to_string())),
284 }
285 }
286 }
287}
288
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
    /// Forwards writes to the inner writer; writing after close() fails
    /// with `NotConnected`.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        match &mut self.task_and_writer {
            None => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "already closed",
            ))),
            Some((_, writer)) => {
                // W: Unpin, so pinning the &mut reference on the stack is sound.
                let pinned_writer = pin!(writer);
                pinned_writer.poll_write(cx, buf)
            }
        }
    }

    /// Forwards flushes to the inner writer; flushing after close() fails
    /// with `NotConnected`.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        match &mut self.task_and_writer {
            None => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "already closed",
            ))),
            Some((_, writer)) => {
                let pinned_writer = pin!(writer);
                pinned_writer.poll_flush(cx)
            }
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        // Deliberate no-op: the inner writer is shut down in
        // BlobWriter::close(), which also awaits the server response.
        Poll::Ready(Ok(()))
    }
}
332
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::Retry;
    use tokio_retry::strategy::ExponentialBackoff;
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::blobservice::MemoryBlobService;
    use crate::fixtures;
    use crate::proto::GRPCBlobServiceWrapper;
    use crate::proto::blob_service_client::BlobServiceClient;

    use super::BlobService;
    use super::GRPCBlobService;

    /// Spins up an in-process gRPC server (backed by a MemoryBlobService)
    /// on a unix socket, connects a [GRPCBlobService] client to it, and
    /// verifies a `has()` round-trip for a digest the server doesn't have.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        let path_clone = socket_path.clone();

        // Server task: bind the socket and serve the wrapped blob service.
        tokio::spawn(async {
            let uds = UnixListener::bind(path_clone).unwrap();
            let uds_stream = UnixListenerStream::new(uds);

            let mut server = tonic::transport::Server::builder();
            let router =
                server.add_service(crate::proto::blob_service_server::BlobServiceServer::new(
                    GRPCBlobServiceWrapper::new(
                        Box::<MemoryBlobService>::default() as Box<dyn BlobService>
                    ),
                ));
            router.serve_with_incoming(uds_stream).await
        });

        // The server binds asynchronously; poll with backoff until the
        // socket file exists before attempting to connect.
        Retry::spawn(
            ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10)),
            || async {
                if socket_path.exists() {
                    Ok(())
                } else {
                    Err(())
                }
            },
        )
        .await
        .expect("failed to wait for socket");

        // Build the client; `wait-connect=1` makes channel_from_url block
        // until the connection is actually established.
        let grpc_client = {
            let url = url::Url::parse(&format!(
                "grpc+unix:{}?wait-connect=1",
                socket_path.display()
            ))
            .expect("must parse");
            let client = BlobServiceClient::new(
                crate::tonic::channel_from_url(&url)
                    .await
                    .expect("must succeed"),
            );
            GRPCBlobService::from_client("root".into(), client)
        };

        // A fresh MemoryBlobService holds no blobs, so `has` must be false.
        let has = grpc_client
            .has(&fixtures::BLOB_A_DIGEST)
            .await
            .expect("must not be err");

        assert!(!has);
    }
}