snix_castore/directoryservice/grpc.rs

use std::collections::HashSet;

use super::{Directory, DirectoryPutter, DirectoryService};
use crate::composition::{CompositionContext, ServiceBuilder};
use crate::proto::{self, get_directory_request::ByWhat};
use crate::{B3Digest, DirectoryError, Error};
use async_stream::try_stream;
use futures::stream::BoxStream;
use std::sync::Arc;
use tokio::spawn;
use tokio::sync::mpsc::UnboundedSender;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{Code, Status, async_trait};
use tracing::{Instrument as _, instrument};

/// Connects to a (remote) snix-store DirectoryService over gRPC.
#[derive(Clone)]
pub struct GRPCDirectoryService<T> {
    instance_name: String,
    /// The internal reference to a gRPC client.
    /// Cloning it is cheap, and it internally handles concurrent requests.
    grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
}

impl<T> GRPCDirectoryService<T> {
    /// Constructs a [GRPCDirectoryService] from a [proto::directory_service_client::DirectoryServiceClient].
    /// Panics if called outside the context of a tokio runtime.
    pub fn from_client(
        instance_name: String,
        grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
    ) -> Self {
        Self {
            instance_name,
            grpc_client,
        }
    }
}
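
// Construction sketch (illustrative only): outside of the composition machinery
// further down in this file, a service is typically wired up from a channel,
// mirroring what `GRPCDirectoryServiceConfig::build` and the tests below do.
// The URL and instance name here are hypothetical examples.
//
//   let url = url::Url::parse("grpc+unix:///path/to/socket")?;
//   let channel = crate::tonic::channel_from_url(&url).await?;
//   let client = proto::directory_service_client::DirectoryServiceClient::new(channel);
//   let svc = GRPCDirectoryService::from_client("default".into(), client);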

#[async_trait]
impl<T> DirectoryService for GRPCDirectoryService<T>
where
    T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone + 'static,
    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
    T::Future: Send,
{
    #[instrument(level = "trace", skip_all, fields(directory.digest = %digest, instance_name = %self.instance_name))]
    async fn get(&self, digest: &B3Digest) -> Result<Option<Directory>, crate::Error> {
        // Get a new handle to the gRPC client, and copy the digest.
        let mut grpc_client = self.grpc_client.clone();
        let digest_cpy = digest.clone();
        let message = async move {
            let mut s = grpc_client
                .get(proto::GetDirectoryRequest {
                    recursive: false,
                    by_what: Some(ByWhat::Digest(digest_cpy.into())),
                })
                .await?
                .into_inner();

            // Retrieve the first message only, then close the stream (we set recursive to false)
            s.message().await
        };

        let digest = digest.clone();
        match message.await {
            Ok(Some(directory)) => {
                // Validate the retrieved Directory indeed has the
                // digest we expect it to have, to detect corruption.
                let actual_digest = directory.digest();
                if actual_digest != digest {
                    Err(crate::Error::StorageError(format!(
                        "requested directory with digest {digest}, but got {actual_digest}"
                    )))
                } else {
                    Ok(Some(directory.try_into().map_err(|_| {
                        Error::StorageError("invalid root digest length in response".to_string())
                    })?))
                }
            }
            Ok(None) => Ok(None),
            Err(e) if e.code() == Code::NotFound => Ok(None),
            Err(e) => Err(crate::Error::StorageError(e.to_string())),
        }
    }

    #[instrument(level = "trace", skip_all, fields(directory.digest = %directory.digest(), instance_name = %self.instance_name))]
    async fn put(&self, directory: Directory) -> Result<B3Digest, crate::Error> {
        let resp = self
            .grpc_client
            .clone()
            .put(tokio_stream::once(proto::Directory::from(directory)))
            .await;

        match resp {
            Ok(put_directory_resp) => Ok(put_directory_resp
                .into_inner()
                .root_digest
                .try_into()
                .map_err(|_| {
                    Error::StorageError("invalid root digest length in response".to_string())
                })?),
            Err(e) => Err(crate::Error::StorageError(e.to_string())),
        }
    }

    #[instrument(level = "trace", skip_all, fields(directory.digest = %root_directory_digest, instance_name = %self.instance_name))]
    fn get_recursive(
        &self,
        root_directory_digest: &B3Digest,
    ) -> BoxStream<'static, Result<Directory, Error>> {
        let mut grpc_client = self.grpc_client.clone();
        let root_directory_digest = root_directory_digest.clone();

        let stream = try_stream! {
            let mut stream = grpc_client
                .get(proto::GetDirectoryRequest {
                    recursive: true,
                    by_what: Some(ByWhat::Digest(root_directory_digest.clone().into())),
                })
                .await
                .map_err(|e| crate::Error::StorageError(e.to_string()))?
                .into_inner();

            // The Directory digests we received so far
            let mut received_directory_digests: HashSet<B3Digest> = HashSet::new();
            // The Directory digests we still expect to be sent.
            let mut expected_directory_digests: HashSet<B3Digest> = HashSet::from([root_directory_digest.clone()]);
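            // Worked example of the bookkeeping (illustrative): for a closure
            // R -> {A, B}, A -> {C}, expected starts as {R}; after receiving R
            // it becomes {A, B}, after A it becomes {B, C}, and the stream is
            // only complete once everything still expected has been received.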

            loop {
                match stream.message().await {
                    Ok(Some(directory)) => {
                        // validate we actually expected that directory, and move it from expected to received.
                        let directory_digest = directory.digest();
                        let was_expected = expected_directory_digests.remove(&directory_digest);
                        if !was_expected {
                            // FUTUREWORK: dumb clients might send the same stuff twice.
                            // as a fallback, we might want to tolerate receiving
                            // it if it's in received_directory_digests (as that
                            // means it once was in expected_directory_digests)
                            Err(crate::Error::StorageError(format!(
                                "received unexpected directory {directory_digest}"
                            )))?;
                        }
                        received_directory_digests.insert(directory_digest);

                        // register all children in expected_directory_digests.
                        for child_directory in &directory.directories {
                            // The digest length is checked here; the full Directory
                            // message is validated below, when converting it into a
                            // crate::Directory.
                            let child_directory_digest: B3Digest =
                                child_directory.digest.clone().try_into().map_err(|_| {
                                    Error::StorageError(
                                        "invalid child directory digest length in response"
                                            .to_string(),
                                    )
                                })?;

                            expected_directory_digests.insert(child_directory_digest);
                        }

                        let directory = directory.try_into()
                            .map_err(|e: DirectoryError| Error::StorageError(e.to_string()))?;

                        yield directory;
                    },
                    Ok(None) if expected_directory_digests.len() == 1 && expected_directory_digests.contains(&root_directory_digest) => {
                        // The root directory of the requested closure was not found;
                        // return an empty stream.
                        return
                    }
                    Ok(None) => {
                        // The stream has ended
                        let diff_len = expected_directory_digests
                            // Account for directories which have been referenced more than once,
                            // but only received once since they were deduplicated
                            .difference(&received_directory_digests)
                            .count();
                        // If this is not empty, then the closure is incomplete
                        if diff_len != 0 {
                            Err(crate::Error::StorageError(format!(
                                "still expected {diff_len} directories, but got premature end of stream"
                            )))?
                        } else {
                            return
                        }
                    },
                    Err(e) => {
                        Err(crate::Error::StorageError(e.to_string()))?;
                    },
                }
            }
        };

        Box::pin(stream)
    }

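    /// Spawns a background task owning the client-streaming put request, and
    /// returns a [GRPCPutter] that feeds it through an unbounded channel.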
    #[instrument(skip_all)]
    fn put_multiple_start(&self) -> Box<(dyn DirectoryPutter + 'static)> {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

        let task = spawn({
            let mut grpc_client = self.grpc_client.clone();

            async move {
                Ok::<_, Status>(
                    grpc_client
                        .put(UnboundedReceiverStream::new(rx))
                        .await?
                        .into_inner(),
                )
            }
            // Instrument the task with the current span; this is not done by default.
            .in_current_span()
        });

        Box::new(GRPCPutter {
            rq: Some((task, tx)),
        })
    }
}

#[derive(serde::Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct GRPCDirectoryServiceConfig {
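    /// URL of the directory service. Stored as a string here, and parsed
    /// again when the channel is constructed in `ServiceBuilder::build`.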
    url: String,
}

impl TryFrom<url::Url> for GRPCDirectoryServiceConfig {
    type Error = Box<dyn std::error::Error + Send + Sync>;
    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        // The URL scheme is normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
        // - In the case of unix sockets, there must be a path, but there may not be a host.
        // - In the case of non-unix sockets, there must be a host, but no path.
        // Constructing the channel is handled by crate::tonic::channel_from_url.
        Ok(GRPCDirectoryServiceConfig {
            url: url.to_string(),
        })
    }
}
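
// Example configuration URLs (illustrative; the socket path and host/port are
// hypothetical, and the actual scheme handling happens in
// crate::tonic::channel_from_url when the channel is built):
//
//   grpc+unix:///path/to/some/socket
//   grpc+http://localhost:8000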

#[async_trait]
impl ServiceBuilder for GRPCDirectoryServiceConfig {
    type Output = dyn DirectoryService;
    async fn build<'a>(
        &'a self,
        instance_name: &str,
        _context: &CompositionContext,
    ) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let client = proto::directory_service_client::DirectoryServiceClient::with_interceptor(
            crate::tonic::channel_from_url(&self.url.parse()?).await?,
            // tonic::service::Interceptor wants an unboxed Status as return type.
            // https://github.com/hyperium/tonic/issues/2253
            |rq| snix_tracing::propagate::tonic::send_trace(rq).map_err(|e| *e),
        );
        Ok(Arc::new(GRPCDirectoryService::from_client(
            instance_name.to_string(),
            client,
        )))
    }
}

/// Allows uploading multiple Directory messages in the same gRPC stream.
pub struct GRPCPutter {
    /// Data about the current request: a handle to the task, and the tx side
    /// of the channel.
    /// The tx side is used to send [proto::Directory] messages to the ongoing request.
    /// The task will yield a [proto::PutDirectoryResponse] once the stream is closed.
    #[allow(clippy::type_complexity)] // lol
    rq: Option<(
        JoinHandle<Result<proto::PutDirectoryResponse, Status>>,
        UnboundedSender<proto::Directory>,
    )>,
}
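
// Usage sketch (illustrative; `directories` stands in for some iterator of
// Directory values, uploaded children-before-parents so that the root is sent
// last and its digest is returned by close()):
//
//   let mut putter = directory_service.put_multiple_start();
//   for directory in directories {
//       putter.put(directory).await?;
//   }
//   let root_digest = putter.close().await?;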

#[async_trait]
impl DirectoryPutter for GRPCPutter {
    #[instrument(level = "trace", skip_all, fields(directory.digest=%directory.digest()), err)]
    async fn put(&mut self, directory: Directory) -> Result<(), crate::Error> {
        match self.rq {
            // If we're not already closed, send the directory to directory_sender.
            Some((_, ref directory_sender)) => {
                if directory_sender.send(directory.into()).is_err() {
                    // If the channel has been prematurely closed, invoke close()
                    // so we surface the error status from the server, which is
                    // much more helpful than the bare send error.
                    self.close().await?;
                }
                Ok(())
            }
            // If self.close() was already called, we can't put again.
            None => Err(Error::StorageError(
                "DirectoryPutter already closed".to_string(),
            )),
        }
    }

    /// Closes the stream for sending, and returns the root digest reported by the server.
    #[instrument(level = "trace", skip_all, ret, err)]
    async fn close(&mut self) -> Result<B3Digest, crate::Error> {
        // get self.rq, and replace it with None.
        // This ensures we can only close it once.
        match std::mem::take(&mut self.rq) {
            None => Err(Error::StorageError("already closed".to_string())),
            Some((task, directory_sender)) => {
                // Close directory_sender, so awaiting the task will finish.
                drop(directory_sender);

                let root_digest = task
                    .await?
                    .map_err(|e| Error::StorageError(e.to_string()))?
                    .root_digest;

                root_digest.try_into().map_err(|_| {
                    Error::StorageError("invalid root digest length in response".to_string())
                })
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;
    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::{Retry, strategy::ExponentialBackoff};
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::{
        directoryservice::{DirectoryService, GRPCDirectoryService, MemoryDirectoryService},
        fixtures,
        proto::{GRPCDirectoryServiceWrapper, directory_service_client::DirectoryServiceClient},
    };

    /// This ensures connecting via gRPC works as expected.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        let path_clone = socket_path.clone();

        // Spin up a server
        tokio::spawn(async {
            let uds = UnixListener::bind(path_clone).unwrap();
            let uds_stream = UnixListenerStream::new(uds);

            // Build the gRPC server and register the DirectoryService wrapper.
            let mut server = tonic::transport::Server::builder();
            let router = server.add_service(
                crate::proto::directory_service_server::DirectoryServiceServer::new(
                    GRPCDirectoryServiceWrapper::new(
                        Box::<MemoryDirectoryService>::default() as Box<dyn DirectoryService>
                    ),
                ),
            );
            router.serve_with_incoming(uds_stream).await
        });

        // wait for the socket to be created
        Retry::spawn(
            ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10)),
            || async {
                if socket_path.exists() {
                    Ok(())
                } else {
                    Err(())
                }
            },
        )
        .await
        .expect("failed to wait for socket");

        // prepare a client
        let grpc_client = {
            let url = url::Url::parse(&format!(
                "grpc+unix://{}?wait-connect=1",
                socket_path.display()
            ))
            .expect("must parse");
            let client = DirectoryServiceClient::new(
                crate::tonic::channel_from_url(&url)
                    .await
                    .expect("must succeed"),
            );
            GRPCDirectoryService::from_client("test-instance".into(), client)
        };

        assert!(
            grpc_client
                .get(&fixtures::DIRECTORY_A.digest())
                .await
                .expect("must not fail")
                .is_none()
        )
    }
}