Skip to main content

snix_castore/directoryservice/
grpc.rs

1use super::{Directory, DirectoryPutter, DirectoryService};
2use crate::B3Digest;
3use crate::composition::{CompositionContext, ServiceBuilder};
4use crate::directoryservice::RootToLeavesValidator;
5use crate::proto::{self, get_directory_request::ByWhat};
6use futures::StreamExt;
7use futures::stream::BoxStream;
8use std::sync::Arc;
9use tokio::spawn;
10use tokio::sync::mpsc::UnboundedSender;
11use tokio::task::JoinHandle;
12use tokio_stream::wrappers::UnboundedReceiverStream;
13use tonic::{Code, Status, async_trait};
14use tracing::{Instrument as _, instrument, warn};
15
16/// Connects to a (remote) snix-store DirectoryService over gRPC.
17#[derive(Clone)]
18pub struct GRPCDirectoryService<T> {
19    instance_name: String,
20    /// The internal reference to a gRPC client.
21    /// Cloning it is cheap, and it internally handles concurrent requests.
22    grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
23}
24
25impl<T> GRPCDirectoryService<T> {
26    /// construct a [GRPCDirectoryService] from a [proto::directory_service_client::DirectoryServiceClient].
27    /// panics if called outside the context of a tokio runtime.
28    pub fn from_client(
29        instance_name: String,
30        grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
31    ) -> Self {
32        Self {
33            instance_name,
34            grpc_client,
35        }
36    }
37}
38
39#[async_trait]
40impl<T> DirectoryService for GRPCDirectoryService<T>
41where
42    T: tonic::client::GrpcService<tonic::body::Body> + Send + Sync + Clone + 'static,
43    T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
44    <T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
45    T::Future: Send,
46{
47    #[instrument(level = "trace", skip_all, fields(directory.digest = %digest, instance_name = %self.instance_name))]
48    async fn get(&self, digest: &B3Digest) -> Result<Option<Directory>, super::Error> {
49        // clone the client, as it takes a &mut.
50        // We retrieve the first message only, then close the stream (we set recursive to false)
51        match self
52            .grpc_client
53            .clone()
54            .get(proto::GetDirectoryRequest {
55                recursive: false,
56                by_what: Some(ByWhat::Digest((*digest).into())),
57            })
58            .await
59            .map_err(Error::Tonic)?
60            .into_inner()
61            .message()
62            .await
63        {
64            Ok(Some(proto_directory)) => {
65                // Validate the retrieved Directory indeed has the
66                // digest we expect it to have, to detect corruptions.
67                let actual_digest = proto_directory.digest();
68                if &actual_digest != digest {
69                    Err(Error::WrongDigest {
70                        expected: *digest,
71                        actual: actual_digest,
72                    })?
73                } else {
74                    let directory =
75                        Directory::try_from(proto_directory).map_err(Error::DirectoryValidation)?;
76                    Ok(Some(directory))
77                }
78            }
79            Ok(None) => Ok(None),
80            Err(e) if e.code() == Code::NotFound => Ok(None),
81            Err(e) => Err(Error::Tonic(e))?,
82        }
83    }
84
85    #[instrument(level = "trace", skip_all, fields(directory.digest = %directory.digest(), instance_name = %self.instance_name))]
86    async fn put(&self, directory: Directory) -> Result<B3Digest, super::Error> {
87        let resp = self
88            .grpc_client
89            .clone()
90            .put(tokio_stream::once(proto::Directory::from(directory)))
91            .await
92            .map_err(Error::Tonic)?;
93
94        let digest = resp
95            .into_inner()
96            .root_digest
97            .try_into()
98            .map_err(|_| Error::InvalidDigestLen)?;
99
100        Ok(digest)
101    }
102
103    #[instrument(level = "trace", skip_all, fields(directory.digest = %root_directory_digest, instance_name = %self.instance_name))]
104    fn get_recursive(
105        &self,
106        root_directory_digest: &B3Digest,
107    ) -> BoxStream<'static, Result<Directory, super::Error>> {
108        let mut grpc_client = self.grpc_client.clone();
109        let root_directory_digest = *root_directory_digest;
110
111        let mut order_validator =
112            RootToLeavesValidator::new_with_root_digest(root_directory_digest);
113
114        async_stream::try_stream! {
115            let mut directories = grpc_client
116                .get(proto::GetDirectoryRequest {
117                    recursive: true,
118                    by_what: Some(ByWhat::Digest((root_directory_digest).into())),
119                })
120                .await
121                .map_err(Error::Tonic)?
122                .into_inner();
123
124            while let Some(proto_directory) = directories.message().await.map_err(Error::Tonic)? {
125                let directory = Directory::try_from(proto_directory).map_err(Error::DirectoryValidation)?;
126                order_validator.try_accept(&directory).map_err(Error::DirectoryOrdering)?;
127
128                yield directory;
129            }
130        }
131        .boxed()
132    }
133
134    #[instrument(skip_all)]
135    fn put_multiple_start(&self) -> Box<dyn DirectoryPutter + 'static> {
136        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
137
138        let task = spawn({
139            let mut grpc_client = self.grpc_client.clone();
140
141            async move {
142                Ok::<_, Status>(
143                    grpc_client
144                        .put(UnboundedReceiverStream::new(rx))
145                        .await?
146                        .into_inner(),
147                )
148            }
149            // instrument the task with the current span, this is not done by default
150            .in_current_span()
151        });
152
153        Box::new(GRPCPutter {
154            rq: Some((task, tx)),
155        })
156    }
157}
158
159/// Allows uploading multiple Directory messages in the same gRPC stream.
160pub struct GRPCPutter {
161    /// Data about the current request - a handle to the task, and the tx part
162    /// of the channel.
163    /// The tx part of the pipe is used to send [proto::Directory] to the ongoing request.
164    /// The task will yield a [proto::PutDirectoryResponse] once the stream is closed.
165    #[allow(clippy::type_complexity)] // lol
166    rq: Option<(
167        JoinHandle<Result<proto::PutDirectoryResponse, Status>>,
168        UnboundedSender<proto::Directory>,
169    )>,
170}
171
172#[async_trait]
173impl DirectoryPutter for GRPCPutter {
174    #[instrument(level = "trace", skip_all, fields(directory.digest=%directory.digest()), err)]
175    async fn put(&mut self, directory: Directory) -> Result<(), super::Error> {
176        let (_, directory_sender) = self
177            .rq
178            .as_ref()
179            .ok_or_else(|| Error::DirectoryPutterAlreadyClosed)?;
180        // If we're not already closed, send the directory to directory_sender.
181        if directory_sender.send(directory.into()).is_err() {
182            // If the channel has been prematurely closed, invoke close (so we can peek at the error code)
183            // That error code is much more helpful, because it
184            // contains the error message from the server.
185            self.close().await?;
186        }
187        Ok(())
188    }
189
190    /// Closes the stream for sending, and returns the value.
191    #[instrument(level = "trace", skip_all, ret, err)]
192    async fn close(&mut self) -> Result<B3Digest, super::Error> {
193        // get self.rq, and replace it with None.
194        // This ensures we can only close it once.
195        let (task, directory_sender) =
196            std::mem::take(&mut self.rq).ok_or_else(|| Error::DirectoryPutterAlreadyClosed)?;
197
198        // close directory_sender, so blocking on task will finish.
199        drop(directory_sender);
200
201        let resp = task
202            .await
203            .map_err(Error::TokioJoin)?
204            .map_err(Error::Tonic)?;
205
206        Ok(B3Digest::try_from(resp.root_digest).map_err(|_| Error::InvalidDigestLen)?)
207    }
208}
209
210#[derive(thiserror::Error, Debug)]
211pub enum Error {
212    #[error("Directory Graph ordering error: {0}")]
213    DirectoryOrdering(#[from] crate::directoryservice::OrderingError),
214
215    #[error("DirectoryPutter already closed")]
216    DirectoryPutterAlreadyClosed,
217
218    #[error("requested directory has wrong digest, expected {expected}, actual {actual}")]
219    WrongDigest {
220        expected: B3Digest,
221        actual: B3Digest,
222    },
223
224    #[error("tonic status: {0}")]
225    Tonic(#[from] tonic::Status),
226
227    #[error("invalid digest length returned from put")]
228    InvalidDigestLen,
229
230    #[error("failed to decode protobuf: {0}")]
231    ProtobufDecode(#[from] prost::DecodeError),
232    #[error("failed to validate directory: {0}")]
233    DirectoryValidation(#[from] crate::DirectoryError),
234
235    #[error("join error: {0}")]
236    TokioJoin(#[from] tokio::task::JoinError),
237    #[error("io error: {0}")]
238    IO(#[from] std::io::Error),
239}
240
241impl From<Error> for super::Error {
242    fn from(value: Error) -> Self {
243        Self(Box::new(value))
244    }
245}
246
247#[derive(serde::Deserialize, Debug)]
248#[serde(deny_unknown_fields)]
249pub struct GRPCDirectoryServiceConfig {
250    url: String,
251}
252
253impl TryFrom<url::Url> for GRPCDirectoryServiceConfig {
254    type Error = Box<dyn std::error::Error + Send + Sync>;
255    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
256        //   This is normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
257        // - In the case of unix sockets, there must be a path, but may not be a host.
258        // - In the case of non-unix sockets, there must be a host, but no path.
259        // Constructing the channel is handled by snix_castore::channel::from_url.
260        Ok(GRPCDirectoryServiceConfig {
261            url: url.to_string(),
262        })
263    }
264}
265
266#[async_trait]
267impl ServiceBuilder for GRPCDirectoryServiceConfig {
268    type Output = dyn DirectoryService;
269    async fn build<'a>(
270        &'a self,
271        instance_name: &str,
272        _context: &CompositionContext,
273    ) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync>> {
274        let client = proto::directory_service_client::DirectoryServiceClient::with_interceptor(
275            crate::tonic::channel_from_url(&self.url.parse()?).await?,
276            snix_tracing::propagate::tonic::send_trace,
277        );
278        Ok(Arc::new(GRPCDirectoryService::from_client(
279            instance_name.to_string(),
280            client,
281        )))
282    }
283}
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::{Retry, strategy::ExponentialBackoff};
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::{
        directoryservice::{DirectoryService, GRPCDirectoryService},
        fixtures,
        proto::{GRPCDirectoryServiceWrapper, directory_service_client::DirectoryServiceClient},
        utils::gen_test_directory_service,
    };

    /// End-to-end check over a unix socket: a gRPC client talking to a server
    /// backed by a freshly generated test DirectoryService gets `None`, not an
    /// error, when asking for a directory that was never uploaded.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        // Serve the test DirectoryService on the socket in the background.
        let listen_path = socket_path.clone();
        tokio::spawn(async move {
            let incoming = UnixListenerStream::new(UnixListener::bind(listen_path).unwrap());
            let svc = crate::proto::directory_service_server::DirectoryServiceServer::new(
                GRPCDirectoryServiceWrapper::new(Box::new(gen_test_directory_service())),
            );

            let mut builder = tonic::transport::Server::builder();
            builder.add_service(svc).serve_with_incoming(incoming).await
        });

        // Poll until the server has created the socket file.
        let socket_exists = || async {
            if socket_path.exists() {
                Ok(())
            } else {
                Err(())
            }
        };
        Retry::spawn(
            ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10)),
            socket_exists,
        )
        .await
        .expect("failed to wait for socket");

        // Connect a client over the socket.
        let url = url::Url::parse(&format!(
            "grpc+unix:{}?wait-connect=1",
            socket_path.display()
        ))
        .expect("must parse");
        let channel = crate::tonic::channel_from_url(&url)
            .await
            .expect("must succeed");
        let grpc_client = GRPCDirectoryService::from_client(
            "test-instance".into(),
            DirectoryServiceClient::new(channel),
        );

        let maybe_directory = grpc_client
            .get(&fixtures::DIRECTORY_A.digest())
            .await
            .expect("must not fail");
        assert!(maybe_directory.is_none())
    }
}