snix_castore/blobservice/grpc.rs

1use super::{BlobReader, BlobService, BlobWriter, ChunkedReader};
2use crate::composition::{CompositionContext, ServiceBuilder};
3use crate::{
4    B3Digest,
5    proto::{self, stat_blob_response::ChunkMeta},
6};
7use futures::sink::SinkExt;
8use std::{
9    io::{self, Cursor},
10    pin::pin,
11    sync::Arc,
12    task::Poll,
13};
14use tokio::io::AsyncWriteExt;
15use tokio::task::JoinHandle;
16use tokio_stream::{StreamExt, wrappers::ReceiverStream};
17use tokio_util::{
18    io::{CopyToBytes, SinkWriter},
19    sync::PollSender,
20};
21use tonic::{Code, Status, async_trait};
22use tracing::{Instrument, Span, instrument};
23
24/// Connects to a (remote) snix-store BlobService over gRPC.
25#[derive(Clone)]
26pub struct GRPCBlobService<T> {
27    instance_name: String,
28    /// The internal reference to a gRPC client.
29    /// Cloning it is cheap, and it internally handles concurrent requests.
30    grpc_client: proto::blob_service_client::BlobServiceClient<T>,
31}
32
33impl<T> GRPCBlobService<T> {
34    /// construct a [GRPCBlobService] from a [proto::blob_service_client::BlobServiceClient].
35    pub fn from_client(
36        instance_name: String,
37        grpc_client: proto::blob_service_client::BlobServiceClient<T>,
38    ) -> Self {
39        Self {
40            instance_name,
41            grpc_client,
42        }
43    }
44}
45
46#[async_trait]
47impl<T> BlobService for GRPCBlobService<T>
48where
49    T: tonic::client::GrpcService<tonic::body::Body> + Send + Sync + Clone + 'static,
50    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
51    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
52    T::Future: Send,
53{
54    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name))]
55    async fn has(&self, digest: &B3Digest) -> io::Result<bool> {
56        match self
57            .grpc_client
58            .clone()
59            .stat(proto::StatBlobRequest {
60                digest: (*digest).into(),
61                ..Default::default()
62            })
63            .await
64        {
65            Ok(_blob_meta) => Ok(true),
66            Err(e) if e.code() == Code::NotFound => Ok(false),
67            Err(e) => Err(io::Error::other(e)),
68        }
69    }
70
71    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
72    async fn open_read(&self, digest: &B3Digest) -> io::Result<Option<Box<dyn BlobReader>>> {
73        // First try to get a list of chunks. In case there's only one chunk returned,
74        // buffer its data into a Vec, otherwise use a ChunkedReader.
75        // We previously used NaiveSeeker here, but userland likes to seek backwards too often,
76        // and without store composition this will get very noisy.
77        // FUTUREWORK: use CombinedBlobService and store composition.
78        match self.chunks(digest).await {
79            Ok(None) => Ok(None),
80            Ok(Some(chunks)) => {
81                if chunks.is_empty() || chunks.len() == 1 {
82                    // No more granular chunking info, treat this as an individual chunk.
83                    // Get a stream of [proto::BlobChunk], or return an error if the blob
84                    // doesn't exist.
85                    return match self
86                        .grpc_client
87                        .clone()
88                        .read(proto::ReadBlobRequest {
89                            digest: (*digest).into(),
90                        })
91                        .await
92                    {
93                        Ok(stream) => {
94                            let data_stream = stream.into_inner().map(|e| {
95                                e.map(|c| c.data)
96                                    .map_err(|s| std::io::Error::new(io::ErrorKind::InvalidData, s))
97                            });
98
99                            // Use StreamReader::new to convert to an AsyncRead.
100                            let mut data_reader = tokio_util::io::StreamReader::new(data_stream);
101
102                            let mut buf = Vec::new();
103                            // TODO: only do this up to a certain limit.
104                            tokio::io::copy(&mut data_reader, &mut buf).await?;
105
106                            Ok(Some(Box::new(Cursor::new(buf))))
107                        }
108                        Err(e) if e.code() == Code::NotFound => Ok(None),
109                        Err(e) => Err(io::Error::other(e)),
110                    };
111                }
112
113                // The chunked case. Let ChunkedReader do individual reads.
114                // TODO: we should store the chunking data in some local cache,
115                // so `ChunkedReader` doesn't call `self.chunks` *again* for every chunk.
116                // Think about how store composition will fix this.
117                let chunked_reader = ChunkedReader::from_chunks(
118                    chunks.into_iter().map(|chunk| {
119                        (
120                            chunk.digest.try_into().expect("invalid b3 digest"),
121                            chunk.size,
122                        )
123                    }),
124                    Arc::new(self.clone()) as Arc<dyn BlobService>,
125                );
126                Ok(Some(Box::new(chunked_reader)))
127            }
128            Err(e) => Err(e)?,
129        }
130    }
131
132    /// Returns a BlobWriter, that'll internally wrap each write in a
133    /// [proto::BlobChunk], which is send to the gRPC server.
134    #[instrument(skip_all, fields(instance_name=%self.instance_name))]
135    async fn open_write(&self) -> Box<dyn BlobWriter> {
136        // set up an mpsc channel passing around Bytes.
137        let (tx, rx) = tokio::sync::mpsc::channel::<bytes::Bytes>(64);
138
139        // bytes arriving on the RX side are wrapped inside a
140        // [proto::BlobChunk], and a [ReceiverStream] is constructed.
141        let span = Span::current();
142        let blobchunk_stream = ReceiverStream::new(rx).map(move |x| {
143            let span = tracing::trace_span!(
144                parent: &span,
145                "blob_chunk",
146                blob.size = x.len()
147            );
148
149            span.in_scope(|| {
150                tracing::trace!("constructing BlobChunk");
151            });
152
153            proto::BlobChunk { data: x }
154        });
155
156        // spawn the gRPC put request, which will read from blobchunk_stream.
157        let task = tokio::spawn({
158            let mut grpc_client = self.grpc_client.clone();
159            async move { Ok::<_, Status>(grpc_client.put(blobchunk_stream).await?.into_inner()) }
160                // instrument the task with the current span, this is not done by default
161                .in_current_span()
162        });
163
164        // The tx part of the channel is converted to a sink of byte chunks.
165        let sink = PollSender::new(tx)
166            .sink_map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e));
167
168        // … which is turned into an [tokio::io::AsyncWrite].
169        let writer = SinkWriter::new(CopyToBytes::new(sink));
170
171        Box::new(GRPCBlobWriter {
172            task_and_writer: Some((task, writer)),
173            digest: None,
174        })
175    }
176
177    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
178    async fn chunks(&self, digest: &B3Digest) -> io::Result<Option<Vec<ChunkMeta>>> {
179        let resp = self
180            .grpc_client
181            .clone()
182            .stat(proto::StatBlobRequest {
183                digest: (*digest).into(),
184                send_chunks: true,
185                ..Default::default()
186            })
187            .await;
188
189        match resp {
190            Err(e) if e.code() == Code::NotFound => Ok(None),
191            Err(e) => Err(io::Error::other(e)),
192            Ok(resp) => {
193                let resp = resp.into_inner();
194
195                resp.validate()
196                    .map_err(|e| std::io::Error::new(io::ErrorKind::InvalidData, e))?;
197
198                Ok(Some(resp.chunks))
199            }
200        }
201    }
202}
203
204#[derive(serde::Deserialize, Debug)]
205#[serde(deny_unknown_fields)]
206pub struct GRPCBlobServiceConfig {
207    url: String,
208}
209
210impl TryFrom<url::Url> for GRPCBlobServiceConfig {
211    type Error = Box<dyn std::error::Error + Send + Sync>;
212    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
213        //   normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
214        // - In the case of unix sockets, there must be a path, but may not be a host.
215        // - In the case of non-unix sockets, there must be a host, but no path.
216        // Constructing the channel is handled by snix_castore::channel::from_url.
217        Ok(GRPCBlobServiceConfig {
218            url: url.to_string(),
219        })
220    }
221}
222
223#[async_trait]
224impl ServiceBuilder for GRPCBlobServiceConfig {
225    type Output = dyn BlobService;
226    async fn build<'a>(
227        &'a self,
228        instance_name: &str,
229        _context: &CompositionContext,
230    ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
231        let client = proto::blob_service_client::BlobServiceClient::with_interceptor(
232            crate::tonic::channel_from_url(&self.url.parse()?).await?,
233            snix_tracing::propagate::tonic::send_trace,
234        );
235        Ok(Arc::new(GRPCBlobService::from_client(
236            instance_name.to_string(),
237            client,
238        )))
239    }
240}
241
242pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
243    /// The task containing the put request, and the inner writer, if we're still writing.
244    task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,
245
246    /// The digest that has been returned, if we successfully closed.
247    digest: Option<B3Digest>,
248}
249
250#[async_trait]
251impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
252    async fn close(&mut self) -> io::Result<B3Digest> {
253        if self.task_and_writer.is_none() {
254            // if we're already closed, return the b3 digest, which must exist.
255            // If it doesn't, we already closed and failed once, and didn't handle the error.
256            match &self.digest {
257                Some(digest) => Ok(*digest),
258                None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "already closed")),
259            }
260        } else {
261            let (task, mut writer) = self.task_and_writer.take().unwrap();
262
263            // invoke shutdown, so the inner writer closes its internal tx side of
264            // the channel.
265            writer.shutdown().await?;
266
267            // block on the RPC call to return.
268            // This ensures all chunks are sent out, and have been received by the
269            // backend.
270
271            match task.await? {
272                Ok(resp) => {
273                    // return the digest from the response, and store it in self.digest for subsequent closes.
274                    let digest_len = resp.digest.len();
275                    let digest: B3Digest = resp.digest.try_into().map_err(|_| {
276                        io::Error::other(format!(
277                            "invalid root digest length {digest_len} in response"
278                        ))
279                    })?;
280                    self.digest = Some(digest);
281                    Ok(digest)
282                }
283                Err(e) => Err(io::Error::other(e.to_string())),
284            }
285        }
286    }
287}
288
289impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
290    fn poll_write(
291        mut self: std::pin::Pin<&mut Self>,
292        cx: &mut std::task::Context<'_>,
293        buf: &[u8],
294    ) -> std::task::Poll<Result<usize, io::Error>> {
295        match &mut self.task_and_writer {
296            None => Poll::Ready(Err(io::Error::new(
297                io::ErrorKind::NotConnected,
298                "already closed",
299            ))),
300            Some((_, writer)) => {
301                let pinned_writer = pin!(writer);
302                pinned_writer.poll_write(cx, buf)
303            }
304        }
305    }
306
307    fn poll_flush(
308        mut self: std::pin::Pin<&mut Self>,
309        cx: &mut std::task::Context<'_>,
310    ) -> std::task::Poll<Result<(), io::Error>> {
311        match &mut self.task_and_writer {
312            None => Poll::Ready(Err(io::Error::new(
313                io::ErrorKind::NotConnected,
314                "already closed",
315            ))),
316            Some((_, writer)) => {
317                let pinned_writer = pin!(writer);
318                pinned_writer.poll_flush(cx)
319            }
320        }
321    }
322
323    fn poll_shutdown(
324        self: std::pin::Pin<&mut Self>,
325        _cx: &mut std::task::Context<'_>,
326    ) -> std::task::Poll<Result<(), io::Error>> {
327        // TODO(raitobezarius): this might not be a graceful shutdown of the
328        // channel inside the gRPC connection.
329        Poll::Ready(Ok(()))
330    }
331}
332
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::Retry;
    use tokio_retry::strategy::ExponentialBackoff;
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::blobservice::MemoryBlobService;
    use crate::fixtures;
    use crate::proto::GRPCBlobServiceWrapper;
    use crate::proto::blob_service_client::BlobServiceClient;

    use super::BlobService;
    use super::GRPCBlobService;

    /// This ensures connecting via gRPC works as expected.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        // Spin up a server backed by an in-memory blob service, listening on
        // the unix socket.
        tokio::spawn({
            let socket_path = socket_path.clone();
            async move {
                let uds = UnixListener::bind(socket_path).unwrap();
                let incoming = UnixListenerStream::new(uds);

                tonic::transport::Server::builder()
                    .add_service(crate::proto::blob_service_server::BlobServiceServer::new(
                        GRPCBlobServiceWrapper::new(
                            Box::<MemoryBlobService>::default() as Box<dyn BlobService>
                        ),
                    ))
                    .serve_with_incoming(incoming)
                    .await
            }
        });

        // Wait (with exponential backoff) until the socket file shows up.
        let backoff = ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10));
        Retry::spawn(backoff, || async {
            if socket_path.exists() {
                Ok(())
            } else {
                Err(())
            }
        })
        .await
        .expect("failed to wait for socket");

        // Prepare a client talking to that socket.
        let grpc_client = {
            let url = url::Url::parse(&format!(
                "grpc+unix:{}?wait-connect=1",
                socket_path.display()
            ))
            .expect("must parse");
            let channel = crate::tonic::channel_from_url(&url)
                .await
                .expect("must succeed");
            GRPCBlobService::from_client("root".into(), BlobServiceClient::new(channel))
        };

        // Nothing was uploaded, so the fixture blob must not exist.
        let has = grpc_client
            .has(&fixtures::BLOB_A_DIGEST)
            .await
            .expect("must not be err");

        assert!(!has);
    }
}
411}