snix_castore/blobservice/grpc.rs

use super::{BlobReader, BlobService, BlobWriter, ChunkedReader};
use crate::composition::{CompositionContext, ServiceBuilder};
use crate::{
    B3Digest,
    proto::{self, stat_blob_response::ChunkMeta},
};
use futures::sink::SinkExt;
use std::{
    io::{self, Cursor},
    pin::pin,
    sync::Arc,
    task::Poll,
};
use tokio::io::AsyncWriteExt;
use tokio::task::JoinHandle;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tokio_util::{
    io::{CopyToBytes, SinkWriter},
    sync::PollSender,
};
use tonic::{Code, Status, async_trait};
use tracing::{Instrument as _, instrument};

/// Connects to a (remote) snix-store BlobService over gRPC.
#[derive(Clone)]
pub struct GRPCBlobService<T> {
    instance_name: String,
    /// The internal reference to a gRPC client.
    /// Cloning it is cheap, and it internally handles concurrent requests.
    grpc_client: proto::blob_service_client::BlobServiceClient<T>,
}

impl<T> GRPCBlobService<T> {
    /// Construct a [GRPCBlobService] from a [proto::blob_service_client::BlobServiceClient].
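    ///
    /// # Example
    ///
    /// A minimal sketch; the `channel` is assumed to have been constructed
    /// elsewhere, e.g. via [crate::tonic::channel_from_url]:
    ///
    /// ```ignore
    /// let client = proto::blob_service_client::BlobServiceClient::new(channel);
    /// let svc = GRPCBlobService::from_client("default".into(), client);
    /// ```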
    pub fn from_client(
        instance_name: String,
        grpc_client: proto::blob_service_client::BlobServiceClient<T>,
    ) -> Self {
        Self {
            instance_name,
            grpc_client,
        }
    }
}

#[async_trait]
impl<T> BlobService for GRPCBlobService<T>
where
    T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone + 'static,
    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
    T::Future: Send,
{
    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name))]
    async fn has(&self, digest: &B3Digest) -> io::Result<bool> {
        match self
            .grpc_client
            .clone()
            .stat(proto::StatBlobRequest {
                digest: digest.clone().into(),
                ..Default::default()
            })
            .await
        {
            Ok(_blob_meta) => Ok(true),
            Err(e) if e.code() == Code::NotFound => Ok(false),
            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
        }
    }

    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
    async fn open_read(&self, digest: &B3Digest) -> io::Result<Option<Box<dyn BlobReader>>> {
        // First try to get a list of chunks. In case there's only one chunk returned,
        // buffer its data into a Vec, otherwise use a ChunkedReader.
        // We previously used NaiveSeeker here, but userland likes to seek backwards too often,
        // and without store composition this will get very noisy.
        // FUTUREWORK: use CombinedBlobService and store composition.
        match self.chunks(digest).await {
            Ok(None) => Ok(None),
            Ok(Some(chunks)) => {
                if chunks.is_empty() || chunks.len() == 1 {
                    // No more granular chunking info, treat this as an individual chunk.
                    // Get a stream of [proto::BlobChunk], or return an error if the blob
                    // doesn't exist.
                    return match self
                        .grpc_client
                        .clone()
                        .read(proto::ReadBlobRequest {
                            digest: digest.clone().into(),
                        })
                        .await
                    {
                        Ok(stream) => {
                            let data_stream = stream.into_inner().map(|e| {
                                e.map(|c| c.data)
                                    .map_err(|s| std::io::Error::new(io::ErrorKind::InvalidData, s))
                            });

                            // Use StreamReader::new to convert to an AsyncRead.
                            let mut data_reader = tokio_util::io::StreamReader::new(data_stream);

                            let mut buf = Vec::new();
                            // TODO: only do this up to a certain limit.
                            tokio::io::copy(&mut data_reader, &mut buf).await?;

                            Ok(Some(Box::new(Cursor::new(buf))))
                        }
                        Err(e) if e.code() == Code::NotFound => Ok(None),
                        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
                    };
                }

                // The chunked case. Let ChunkedReader do individual reads.
                // TODO: we should store the chunking data in some local cache,
                // so `ChunkedReader` doesn't call `self.chunks` *again* for every chunk.
                // Think about how store composition will fix this.
                let chunked_reader = ChunkedReader::from_chunks(
                    chunks.into_iter().map(|chunk| {
                        (
                            chunk.digest.try_into().expect("invalid b3 digest"),
                            chunk.size,
                        )
                    }),
                    Arc::new(self.clone()) as Arc<dyn BlobService>,
                );
                Ok(Some(Box::new(chunked_reader)))
            }
            Err(e) => Err(e)?,
        }
    }
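
    // Usage sketch for `open_read` (illustrative only; `svc` and `digest` are
    // hypothetical):
    //
    //     if let Some(mut reader) = svc.open_read(&digest).await? {
    //         let mut buf = Vec::new();
    //         tokio::io::copy(&mut reader, &mut buf).await?;
    //         // buf now contains the entire blob contents.
    //     }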

    /// Returns a BlobWriter that internally wraps each write in a
    /// [proto::BlobChunk], which is sent to the gRPC server.
    #[instrument(skip_all, fields(instance_name=%self.instance_name))]
    async fn open_write(&self) -> Box<dyn BlobWriter> {
        // set up an mpsc channel passing around Bytes.
        let (tx, rx) = tokio::sync::mpsc::channel::<bytes::Bytes>(10);

        // bytes arriving on the RX side are wrapped inside a
        // [proto::BlobChunk], and a [ReceiverStream] is constructed.
        let blobchunk_stream = ReceiverStream::new(rx).map(|x| proto::BlobChunk { data: x });

        // spawn the gRPC put request, which will read from blobchunk_stream.
        let task = tokio::spawn({
            let mut grpc_client = self.grpc_client.clone();
            async move { Ok::<_, Status>(grpc_client.put(blobchunk_stream).await?.into_inner()) }
                // instrument the task with the current span; this is not done by default.
                .in_current_span()
        });

        // The tx part of the channel is converted to a sink of byte chunks…
        let sink = PollSender::new(tx)
            .sink_map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e));

        // … which is turned into a [tokio::io::AsyncWrite].
        let writer = SinkWriter::new(CopyToBytes::new(sink));

        Box::new(GRPCBlobWriter {
            task_and_writer: Some((task, writer)),
            digest: None,
        })
    }
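
    // Usage sketch for `open_write` (illustrative only; `svc` is hypothetical,
    // `write_all` comes from the `tokio::io::AsyncWriteExt` import above):
    //
    //     let mut writer = svc.open_write().await;
    //     writer.write_all(b"hello world").await?;
    //     let digest = writer.close().await?;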

    #[instrument(skip(self, digest), fields(blob.digest=%digest, instance_name=%self.instance_name), err)]
    async fn chunks(&self, digest: &B3Digest) -> io::Result<Option<Vec<ChunkMeta>>> {
        let resp = self
            .grpc_client
            .clone()
            .stat(proto::StatBlobRequest {
                digest: digest.clone().into(),
                send_chunks: true,
                ..Default::default()
            })
            .await;

        match resp {
            Err(e) if e.code() == Code::NotFound => Ok(None),
            Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
            Ok(resp) => {
                let resp = resp.into_inner();

                resp.validate()
                    .map_err(|e| std::io::Error::new(io::ErrorKind::InvalidData, e))?;

                Ok(Some(resp.chunks))
            }
        }
    }
}

#[derive(serde::Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GRPCBlobServiceConfig {
    url: String,
}

impl TryFrom<url::Url> for GRPCBlobServiceConfig {
    type Error = Box<dyn std::error::Error + Send + Sync>;
    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        // The URL scheme is normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
        // - In the case of unix sockets, there must be a path, but may not be a host.
        // - In the case of non-unix sockets, there must be a host, but no path.
        // Constructing the channel is handled by snix_castore::tonic::channel_from_url.
        Ok(GRPCBlobServiceConfig {
            url: url.to_string(),
        })
    }
}
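
// Usage sketch (illustrative only; the socket path is hypothetical):
//
//     let url = url::Url::parse("grpc+unix:///path/to/castore.sock")?;
//     let config = GRPCBlobServiceConfig::try_from(url)?;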

#[async_trait]
impl ServiceBuilder for GRPCBlobServiceConfig {
    type Output = dyn BlobService;
    async fn build<'a>(
        &'a self,
        instance_name: &str,
        _context: &CompositionContext,
    ) -> Result<Arc<dyn BlobService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = proto::blob_service_client::BlobServiceClient::new(
            crate::tonic::channel_from_url(&self.url.parse()?).await?,
        );
        Ok(Arc::new(GRPCBlobService::from_client(
            instance_name.to_string(),
            client,
        )))
    }
}
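
// Usage sketch for `build` (illustrative only; assumes a `context` obtained
// from the composition machinery, which this particular builder doesn't
// inspect):
//
//     let svc = config.build("default", &context).await?;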

pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
    /// The task containing the put request, and the inner writer, if we're still writing.
    task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,

    /// The digest that has been returned, if we successfully closed.
    digest: Option<B3Digest>,
}

#[async_trait]
impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
    async fn close(&mut self) -> io::Result<B3Digest> {
        if self.task_and_writer.is_none() {
            // if we're already closed, return the b3 digest, which must exist.
            // If it doesn't, we already closed and failed once, and didn't handle the error.
            match &self.digest {
                Some(digest) => Ok(digest.clone()),
                None => Err(io::Error::new(io::ErrorKind::BrokenPipe, "already closed")),
            }
        } else {
            let (task, mut writer) = self.task_and_writer.take().unwrap();

            // invoke shutdown, so the inner writer closes its internal tx side of
            // the channel.
            writer.shutdown().await?;

            // Wait for the RPC call to return.
            // This ensures all chunks have been sent out and received by the
            // backend.
            match task.await? {
                Ok(resp) => {
                    // return the digest from the response, and store it in self.digest for subsequent closes.
                    let digest_len = resp.digest.len();
                    let digest: B3Digest = resp.digest.try_into().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::Other,
                            format!("invalid root digest length {} in response", digest_len),
                        )
                    })?;
                    self.digest = Some(digest.clone());
                    Ok(digest)
                }
                Err(e) => Err(io::Error::new(io::ErrorKind::Other, e.to_string())),
            }
        }
    }
}
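
// Note (illustrative): `close` is idempotent on success. The digest is cached
// in `self.digest`, so a second `close` returns it without re-running the RPC:
//
//     let digest = writer.close().await?;
//     assert_eq!(digest, writer.close().await?);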

impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        match &mut self.task_and_writer {
            None => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "already closed",
            ))),
            Some((_, writer)) => {
                let pinned_writer = pin!(writer);
                pinned_writer.poll_write(cx, buf)
            }
        }
    }

    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        match &mut self.task_and_writer {
            None => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "already closed",
            ))),
            Some((_, writer)) => {
                let pinned_writer = pin!(writer);
                pinned_writer.poll_flush(cx)
            }
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        // TODO(raitobezarius): this might not be a graceful shutdown of the
        // channel inside the gRPC connection.
        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::Retry;
    use tokio_retry::strategy::ExponentialBackoff;
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::blobservice::MemoryBlobService;
    use crate::fixtures;
    use crate::proto::GRPCBlobServiceWrapper;
    use crate::proto::blob_service_client::BlobServiceClient;

    use super::BlobService;
    use super::GRPCBlobService;

    /// This ensures connecting via gRPC works as expected.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        let path_clone = socket_path.clone();

        // Spin up a server
        tokio::spawn(async {
            let uds = UnixListener::bind(path_clone).unwrap();
            let uds_stream = UnixListenerStream::new(uds);

            // spin up a new server
            let mut server = tonic::transport::Server::builder();
            let router =
                server.add_service(crate::proto::blob_service_server::BlobServiceServer::new(
                    GRPCBlobServiceWrapper::new(
                        Box::<MemoryBlobService>::default() as Box<dyn BlobService>
                    ),
                ));
            router.serve_with_incoming(uds_stream).await
        });

        // wait for the socket to be created
        Retry::spawn(
            ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10)),
            || async {
                if socket_path.exists() {
                    Ok(())
                } else {
                    Err(())
                }
            },
        )
        .await
        .expect("failed to wait for socket");

        // prepare a client
        let grpc_client = {
            let url = url::Url::parse(&format!(
                "grpc+unix://{}?wait-connect=1",
                socket_path.display()
            ))
            .expect("must parse");
            let client = BlobServiceClient::new(
                crate::tonic::channel_from_url(&url)
                    .await
                    .expect("must succeed"),
            );
            GRPCBlobService::from_client("root".into(), client)
        };

        let has = grpc_client
            .has(&fixtures::BLOB_A_DIGEST)
            .await
            .expect("must not be err");

        assert!(!has);
    }
}