snix_castore/directoryservice/
grpc.rs

1use std::collections::HashSet;
2
3use super::{Directory, DirectoryPutter, DirectoryService};
4use crate::composition::{CompositionContext, ServiceBuilder};
5use crate::proto::{self, get_directory_request::ByWhat};
6use crate::{B3Digest, DirectoryError, Error};
7use async_stream::try_stream;
8use futures::stream::BoxStream;
9use std::sync::Arc;
10use tokio::spawn;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio::task::JoinHandle;
13use tokio_stream::wrappers::UnboundedReceiverStream;
14use tonic::{Code, Status, async_trait};
15use tracing::{Instrument as _, instrument, warn};
16
17/// Connects to a (remote) snix-store DirectoryService over gRPC.
18#[derive(Clone)]
19pub struct GRPCDirectoryService<T> {
20    instance_name: String,
21    /// The internal reference to a gRPC client.
22    /// Cloning it is cheap, and it internally handles concurrent requests.
23    grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
24}
25
26impl<T> GRPCDirectoryService<T> {
27    /// construct a [GRPCDirectoryService] from a [proto::directory_service_client::DirectoryServiceClient].
28    /// panics if called outside the context of a tokio runtime.
29    pub fn from_client(
30        instance_name: String,
31        grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
32    ) -> Self {
33        Self {
34            instance_name,
35            grpc_client,
36        }
37    }
38}
39
40#[async_trait]
41impl<T> DirectoryService for GRPCDirectoryService<T>
42where
43    T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone + 'static,
44    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
45    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
46    T::Future: Send,
47{
48    #[instrument(level = "trace", skip_all, fields(directory.digest = %digest, instance_name = %self.instance_name))]
49    async fn get(&self, digest: &B3Digest) -> Result<Option<Directory>, crate::Error> {
50        // Get a new handle to the gRPC client, and copy the digest.
51        let mut grpc_client = self.grpc_client.clone();
52        let digest_cpy = digest.clone();
53        let message = async move {
54            let mut s = grpc_client
55                .get(proto::GetDirectoryRequest {
56                    recursive: false,
57                    by_what: Some(ByWhat::Digest(digest_cpy.into())),
58                })
59                .await?
60                .into_inner();
61
62            // Retrieve the first message only, then close the stream (we set recursive to false)
63            s.message().await
64        };
65
66        let digest = digest.clone();
67        match message.await {
68            Ok(Some(directory)) => {
69                // Validate the retrieved Directory indeed has the
70                // digest we expect it to have, to detect corruptions.
71                let actual_digest = directory.digest();
72                if actual_digest != digest {
73                    Err(crate::Error::StorageError(format!(
74                        "requested directory with digest {}, but got {}",
75                        digest, actual_digest
76                    )))
77                } else {
78                    Ok(Some(directory.try_into().map_err(|_| {
79                        Error::StorageError("invalid root digest length in response".to_string())
80                    })?))
81                }
82            }
83            Ok(None) => Ok(None),
84            Err(e) if e.code() == Code::NotFound => Ok(None),
85            Err(e) => Err(crate::Error::StorageError(e.to_string())),
86        }
87    }
88
89    #[instrument(level = "trace", skip_all, fields(directory.digest = %directory.digest(), instance_name = %self.instance_name))]
90    async fn put(&self, directory: Directory) -> Result<B3Digest, crate::Error> {
91        let resp = self
92            .grpc_client
93            .clone()
94            .put(tokio_stream::once(proto::Directory::from(directory)))
95            .await;
96
97        match resp {
98            Ok(put_directory_resp) => Ok(put_directory_resp
99                .into_inner()
100                .root_digest
101                .try_into()
102                .map_err(|_| {
103                    Error::StorageError("invalid root digest length in response".to_string())
104                })?),
105            Err(e) => Err(crate::Error::StorageError(e.to_string())),
106        }
107    }
108
109    #[instrument(level = "trace", skip_all, fields(directory.digest = %root_directory_digest, instance_name = %self.instance_name))]
110    fn get_recursive(
111        &self,
112        root_directory_digest: &B3Digest,
113    ) -> BoxStream<'static, Result<Directory, Error>> {
114        let mut grpc_client = self.grpc_client.clone();
115        let root_directory_digest = root_directory_digest.clone();
116
117        let stream = try_stream! {
118            let mut stream = grpc_client
119                .get(proto::GetDirectoryRequest {
120                    recursive: true,
121                    by_what: Some(ByWhat::Digest(root_directory_digest.clone().into())),
122                })
123                .await
124                .map_err(|e| crate::Error::StorageError(e.to_string()))?
125                .into_inner();
126
127            // The Directory digests we received so far
128            let mut received_directory_digests: HashSet<B3Digest> = HashSet::new();
129            // The Directory digests we're still expecting to get sent.
130            let mut expected_directory_digests: HashSet<B3Digest> = HashSet::from([root_directory_digest.clone()]);
131
132            loop {
133                match stream.message().await {
134                    Ok(Some(directory)) => {
135                        // validate we actually expected that directory, and move it from expected to received.
136                        let directory_digest = directory.digest();
137                        let was_expected = expected_directory_digests.remove(&directory_digest);
138                        if !was_expected {
139                            // FUTUREWORK: dumb clients might send the same stuff twice.
140                            // as a fallback, we might want to tolerate receiving
141                            // it if it's in received_directory_digests (as that
142                            // means it once was in expected_directory_digests)
143                            Err(crate::Error::StorageError(format!(
144                                "received unexpected directory {}",
145                                directory_digest
146                            )))?;
147                        }
148                        received_directory_digests.insert(directory_digest);
149
150                        // register all children in expected_directory_digests.
151                        for child_directory in &directory.directories {
152                            // We ran validate() above, so we know these digests must be correct.
153                            let child_directory_digest =
154                                child_directory.digest.clone().try_into().unwrap();
155
156                            expected_directory_digests
157                                .insert(child_directory_digest);
158                        }
159
160                        let directory = directory.try_into()
161                            .map_err(|e: DirectoryError| Error::StorageError(e.to_string()))?;
162
163                        yield directory;
164                    },
165                    Ok(None) if expected_directory_digests.len() == 1 && expected_directory_digests.contains(&root_directory_digest) => {
166                        // The root directory of the requested closure was not found, return an
167                        // empty stream
168                        return
169                    }
170                    Ok(None) => {
171                        // The stream has ended
172                        let diff_len = expected_directory_digests
173                            // Account for directories which have been referenced more than once,
174                            // but only received once since they were deduplicated
175                            .difference(&received_directory_digests)
176                            .count();
177                        // If this is not empty, then the closure is incomplete
178                        if diff_len != 0 {
179                            Err(crate::Error::StorageError(format!(
180                                "still expected {} directories, but got premature end of stream",
181                                diff_len
182                            )))?
183                        } else {
184                            return
185                        }
186                    },
187                    Err(e) => {
188                        Err(crate::Error::StorageError(e.to_string()))?;
189                    },
190                }
191            }
192        };
193
194        Box::pin(stream)
195    }
196
197    #[instrument(skip_all)]
198    fn put_multiple_start(&self) -> Box<(dyn DirectoryPutter + 'static)> {
199        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
200
201        let task = spawn({
202            let mut grpc_client = self.grpc_client.clone();
203
204            async move {
205                Ok::<_, Status>(
206                    grpc_client
207                        .put(UnboundedReceiverStream::new(rx))
208                        .await?
209                        .into_inner(),
210                )
211            }
212            // instrument the task with the current span, this is not done by default
213            .in_current_span()
214        });
215
216        Box::new(GRPCPutter {
217            rq: Some((task, tx)),
218        })
219    }
220}
221
222#[derive(serde::Deserialize, Debug)]
223#[serde(deny_unknown_fields)]
224pub struct GRPCDirectoryServiceConfig {
225    url: String,
226}
227
228impl TryFrom<url::Url> for GRPCDirectoryServiceConfig {
229    type Error = Box<dyn std::error::Error + Send + Sync>;
230    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
231        //   This is normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
232        // - In the case of unix sockets, there must be a path, but may not be a host.
233        // - In the case of non-unix sockets, there must be a host, but no path.
234        // Constructing the channel is handled by snix_castore::channel::from_url.
235        Ok(GRPCDirectoryServiceConfig {
236            url: url.to_string(),
237        })
238    }
239}
240
241#[async_trait]
242impl ServiceBuilder for GRPCDirectoryServiceConfig {
243    type Output = dyn DirectoryService;
244    async fn build<'a>(
245        &'a self,
246        instance_name: &str,
247        _context: &CompositionContext,
248    ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
249        let client = proto::directory_service_client::DirectoryServiceClient::new(
250            crate::tonic::channel_from_url(&self.url.parse()?).await?,
251        );
252        Ok(Arc::new(GRPCDirectoryService::from_client(
253            instance_name.to_string(),
254            client,
255        )))
256    }
257}
258
259/// Allows uploading multiple Directory messages in the same gRPC stream.
260pub struct GRPCPutter {
261    /// Data about the current request - a handle to the task, and the tx part
262    /// of the channel.
263    /// The tx part of the pipe is used to send [proto::Directory] to the ongoing request.
264    /// The task will yield a [proto::PutDirectoryResponse] once the stream is closed.
265    #[allow(clippy::type_complexity)] // lol
266    rq: Option<(
267        JoinHandle<Result<proto::PutDirectoryResponse, Status>>,
268        UnboundedSender<proto::Directory>,
269    )>,
270}
271
272#[async_trait]
273impl DirectoryPutter for GRPCPutter {
274    #[instrument(level = "trace", skip_all, fields(directory.digest=%directory.digest()), err)]
275    async fn put(&mut self, directory: Directory) -> Result<(), crate::Error> {
276        match self.rq {
277            // If we're not already closed, send the directory to directory_sender.
278            Some((_, ref directory_sender)) => {
279                if directory_sender.send(directory.into()).is_err() {
280                    // If the channel has been prematurely closed, invoke close (so we can peek at the error code)
281                    // That error code is much more helpful, because it
282                    // contains the error message from the server.
283                    self.close().await?;
284                }
285                Ok(())
286            }
287            // If self.close() was already called, we can't put again.
288            None => Err(Error::StorageError(
289                "DirectoryPutter already closed".to_string(),
290            )),
291        }
292    }
293
294    /// Closes the stream for sending, and returns the value.
295    #[instrument(level = "trace", skip_all, ret, err)]
296    async fn close(&mut self) -> Result<B3Digest, crate::Error> {
297        // get self.rq, and replace it with None.
298        // This ensures we can only close it once.
299        match std::mem::take(&mut self.rq) {
300            None => Err(Error::StorageError("already closed".to_string())),
301            Some((task, directory_sender)) => {
302                // close directory_sender, so blocking on task will finish.
303                drop(directory_sender);
304
305                let root_digest = task
306                    .await?
307                    .map_err(|e| Error::StorageError(e.to_string()))?
308                    .root_digest;
309
310                root_digest.try_into().map_err(|_| {
311                    Error::StorageError("invalid root digest length in response".to_string())
312                })
313            }
314        }
315    }
316}
317
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::{Retry, strategy::ExponentialBackoff};
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::{
        directoryservice::{DirectoryService, GRPCDirectoryService, MemoryDirectoryService},
        fixtures,
        proto::{GRPCDirectoryServiceWrapper, directory_service_client::DirectoryServiceClient},
    };

    /// Connects to a gRPC DirectoryService served over a unix socket.
    ///
    /// Checks that asking for a directory that was never uploaded yields
    /// `None` rather than an error.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("daemon");

        // Background server: an in-memory DirectoryService exposed over gRPC on the socket.
        let listen_path = socket_path.clone();
        tokio::spawn(async move {
            let listener = UnixListener::bind(listen_path).unwrap();
            let backend = Box::<MemoryDirectoryService>::default() as Box<dyn DirectoryService>;
            let svc = crate::proto::directory_service_server::DirectoryServiceServer::new(
                GRPCDirectoryServiceWrapper::new(backend),
            );

            tonic::transport::Server::builder()
                .add_service(svc)
                .serve_with_incoming(UnixListenerStream::new(listener))
                .await
        });

        // The server binds asynchronously; poll until the socket file shows up.
        let backoff = ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10));
        Retry::spawn(backoff, || async {
            socket_path.exists().then_some(()).ok_or(())
        })
        .await
        .expect("failed to wait for socket");

        // Client side.
        let url = url::Url::parse(&format!(
            "grpc+unix://{}?wait-connect=1",
            socket_path.display()
        ))
        .expect("must parse");
        let channel = crate::tonic::channel_from_url(&url)
            .await
            .expect("must succeed");
        let grpc_client = GRPCDirectoryService::from_client(
            "test-instance".into(),
            DirectoryServiceClient::new(channel),
        );

        let found = grpc_client
            .get(&fixtures::DIRECTORY_A.digest())
            .await
            .expect("must not fail");
        assert!(found.is_none())
    }
}