nix_compat/nix_daemon/
handler.rs

1use std::{future::Future, ops::DerefMut, sync::Arc};
2
3use bytes::Bytes;
4use tokio::{
5    io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
6    sync::Mutex,
7};
8use tracing::{debug, warn};
9
10use super::{
11    framing::{NixFramedReader, StderrReadFramedReader},
12    types::{AddToStoreNarRequest, QueryValidPaths},
13    worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
14    NixDaemonIO,
15};
16
17use crate::{
18    store_path::StorePath,
19    wire::{
20        de::{NixRead, NixReader},
21        ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
22        ProtocolVersion,
23    },
24};
25
26use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};
27
28/// Handles a single connection with a nix client.
29///
30/// As part of its [`initialization`] it performs the handshake with the client
31/// and determines the [ProtocolVersion] and [ClientSettings] to use for the remainder of the session.
32///
33/// Once initialized, [NixDaemon::handle_client] needs to be called to handle
34/// the rest of the session, it delegates all operation handling to an instance
35/// of [NixDaemonIO].
36///
37/// [`initialization`]: NixDaemon::initialize
38#[allow(dead_code)]
39pub struct NixDaemon<IO, R, W> {
40    io: Arc<IO>,
41    protocol_version: ProtocolVersion,
42    client_settings: ClientSettings,
43    reader: NixReader<R>,
44    writer: Arc<Mutex<NixWriter<W>>>,
45}
46
47impl<IO, R, W> NixDaemon<IO, R, W>
48where
49    IO: NixDaemonIO + Sync + Send,
50{
51    pub fn new(
52        io: Arc<IO>,
53        protocol_version: ProtocolVersion,
54        client_settings: ClientSettings,
55        reader: NixReader<R>,
56        writer: NixWriter<W>,
57    ) -> Self {
58        Self {
59            io,
60            protocol_version,
61            client_settings,
62            reader,
63            writer: Arc::new(Mutex::new(writer)),
64        }
65    }
66}
67
68impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
69where
70    RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
71    IO: NixDaemonIO + Sync + Send,
72{
73    /// Async constructor for NixDaemon.
74    ///
75    /// Performs the initial handshake with the client and retrieves the client's preferred
76    /// settings.
77    ///
78    /// The resulting daemon can handle the client session by calling [NixDaemon::handle_client].
79    pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
80    where
81        RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
82    {
83        let protocol_version =
84            server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
85
86        connection.write_u64_le(STDERR_LAST).await?;
87        let (reader, writer) = split(connection);
88        let mut reader = NixReader::builder()
89            .set_version(protocol_version)
90            .build(reader);
91        let mut writer = NixWriterBuilder::default()
92            .set_version(protocol_version)
93            .build(writer);
94
95        // The first op is always SetOptions
96        let operation: Operation = reader.read_value().await?;
97        if operation != Operation::SetOptions {
98            return Err(std::io::Error::other(
99                "Expected SetOptions operation, but got {operation}",
100            ));
101        }
102        let client_settings: ClientSettings = reader.read_value().await?;
103        writer.write_number(STDERR_LAST).await?;
104        writer.flush().await?;
105
106        Ok(Self::new(
107            io,
108            protocol_version,
109            client_settings,
110            reader,
111            writer,
112        ))
113    }
114
115    /// Main client connection loop, reads client's requests and responds to them accordingly.
116    pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
117        let io = self.io.clone();
118        loop {
119            let op_code = self.reader.read_number().await?;
120            let op = TryInto::<Operation>::try_into(op_code);
121            debug!(?op, "Received operation");
122            match op {
123                // Note: please keep operations sorted in ascending order of their numerical op number.
124                Ok(operation) => match operation {
125                    Operation::IsValidPath => {
126                        let path: StorePath<String> = self.reader.read_value().await?;
127                        Self::handle(&self.writer, io.is_valid_path(&path)).await?
128                    }
129                    // Note this operation does not currently delegate to NixDaemonIO,
130                    // The general idea is that we will pass relevant ClientSettings
131                    // into individual NixDaemonIO method calls if the need arises.
132                    // For now we just store the settings in the NixDaemon for future use.
133                    Operation::SetOptions => {
134                        self.client_settings = self.reader.read_value().await?;
135                        Self::handle(&self.writer, async { Ok(()) }).await?
136                    }
137                    Operation::QueryPathInfo => {
138                        let path: StorePath<String> = self.reader.read_value().await?;
139                        Self::handle(&self.writer, io.query_path_info(&path)).await?
140                    }
141                    Operation::QueryPathFromHashPart => {
142                        let hash: Bytes = self.reader.read_value().await?;
143                        Self::handle(&self.writer, io.query_path_from_hash_part(&hash)).await?
144                    }
145                    Operation::QueryValidPaths => {
146                        let query: QueryValidPaths = self.reader.read_value().await?;
147                        Self::handle(&self.writer, io.query_valid_paths(&query)).await?
148                    }
149                    Operation::QueryValidDerivers => {
150                        let path: StorePath<String> = self.reader.read_value().await?;
151                        Self::handle(&self.writer, io.query_valid_derivers(&path)).await?
152                    }
153                    // FUTUREWORK: These are just stubs that return an empty list.
154                    // It's important not to return an error for the local-overlay:// store
155                    // to work properly. While it will not see certain referrers and realizations
156                    // it will not fail on various operations like gc and optimize store. At the
157                    // same time, returning an empty list here shouldn't break any of local-overlay store's
158                    // invariants.
159                    Operation::QueryReferrers | Operation::QueryRealisation => {
160                        let _: String = self.reader.read_value().await?;
161                        Self::handle(&self.writer, async move {
162                            warn!(
163                                ?operation,
164                                "This operation is not implemented. Returning empty result..."
165                            );
166                            Ok(Vec::<StorePath<String>>::new())
167                        })
168                        .await?
169                    }
170                    Operation::AddToStoreNar => {
171                        let request: AddToStoreNarRequest = self.reader.read_value().await?;
172                        let minor_version = self.protocol_version.minor();
173                        match minor_version {
174                            ..21 => {
175                                // Before protocol version 1.21, the nar is sent unframed, so we just
176                                // pass the reader directly to the operation.
177                                Self::handle(
178                                    &self.writer,
179                                    self.io.add_to_store_nar(request, &mut self.reader),
180                                )
181                                .await?
182                            }
183                            21..23 => {
184                                // Protocol versions 1.21 .. 1.23 use STDERR_READ protocol, see logging.md#stderr_read.
185                                Self::handle(&self.writer, async {
186                                    let mut writer = self.writer.lock().await;
187                                    let mut reader = StderrReadFramedReader::new(
188                                        &mut self.reader,
189                                        writer.deref_mut(),
190                                    );
191                                    self.io.add_to_store_nar(request, &mut reader).await
192                                    // TODO(edef): enforce framing synchronisation
193                                })
194                                .await?
195                            }
196                            23.. => {
197                                // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
198                                let mut framed = NixFramedReader::new(&mut self.reader);
199
200                                Self::handle(&self.writer, async {
201                                    self.io.add_to_store_nar(request, &mut framed).await
202                                })
203                                .await?;
204
205                                // framing desynchronisation
206                                // this MUST kill the connection
207                                if !framed.is_eof() {
208                                    return Err(std::io::Error::new(
209                                        std::io::ErrorKind::InvalidData,
210                                        "payload was not fully consumed",
211                                    ));
212                                }
213                            }
214                        }
215                    }
216                    _ => {
217                        return Err(std::io::Error::other(format!(
218                            "Operation {operation:?} is not implemented"
219                        )));
220                    }
221                },
222                _ => {
223                    return Err(std::io::Error::other(format!(
224                        "Unknown operation code received: {op_code}"
225                    )));
226                }
227            }
228        }
229    }
230
231    /// Handles the operation and sends the response or error to the client.
232    ///
233    /// As per nix daemon protocol, after sending the request, the client expects zero or more
234    /// log lines/activities followed by either
235    /// * STDERR_LAST and the response bytes
236    /// * STDERR_ERROR and the error
237    ///
238    /// This is a helper method, awaiting on the passed in future and then
239    /// handling log lines/activities as described above.
240    async fn handle<T>(
241        writer: &Arc<Mutex<NixWriter<WriteHalf<RW>>>>,
242        future: impl Future<Output = std::io::Result<T>>,
243    ) -> Result<(), std::io::Error>
244    where
245        T: NixSerialize + Send,
246    {
247        let result = future.await;
248        let mut writer = writer.lock().await;
249
250        match result {
251            Ok(r) => {
252                // the protocol requires that we first indicate that we are done sending logs
253                // by sending STDERR_LAST and then the response.
254                writer.write_number(STDERR_LAST).await?;
255                writer.write_value(&r).await?;
256                writer.flush().await
257            }
258            Err(e) => {
259                debug!(err = ?e, "IO error");
260                writer.write_number(STDERR_ERROR).await?;
261                writer.write_value(&NixError::new(format!("{e:?}"))).await?;
262                writer.flush().await
263            }
264        }
265    }
266}
267
268#[cfg(test)]
269mod tests {
270    use super::*;
271    use std::{io::ErrorKind, sync::Arc};
272
273    use mockall::predicate;
274    use tokio::io::AsyncWriteExt;
275
276    use crate::{
277        nix_daemon::MockNixDaemonIO,
278        wire::ProtocolVersion,
279        worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
280    };
281
282    #[tokio::test]
283    async fn test_daemon_initialization() {
284        let mut builder = tokio_test::io::Builder::new();
285        let test_conn = builder
286            .read(&WORKER_MAGIC_1.to_le_bytes())
287            .write(&WORKER_MAGIC_2.to_le_bytes())
288            // Our version is 1.37
289            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
290            // The client's versin is 1.35
291            .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
292            // cpu affinity
293            .read(&[0; 8])
294            // reservespace
295            .read(&[0; 8])
296            // version (size)
297            .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
298            // version (data == 2.18.2 + padding)
299            .write(&[50, 46, 49, 56, 46, 50, 0, 0])
300            // Trusted (1 == client trusted)
301            .write(&[1, 0, 0, 0, 0, 0, 0, 0])
302            // STDERR_LAST
303            .write(&[115, 116, 108, 97, 0, 0, 0, 0]);
304
305        let mut bytes = Vec::new();
306        let mut writer = NixWriter::new(&mut bytes);
307        writer
308            .write_value(&ClientSettings::default())
309            .await
310            .unwrap();
311        writer.flush().await.unwrap();
312
313        let test_conn = test_conn
314            // SetOptions op
315            .read(&[19, 0, 0, 0, 0, 0, 0, 0])
316            .read(&bytes)
317            // STDERR_LAST
318            .write(&[115, 116, 108, 97, 0, 0, 0, 0])
319            .build();
320
321        let mock = MockNixDaemonIO::new();
322        let daemon = NixDaemon::initialize(Arc::new(mock), test_conn)
323            .await
324            .unwrap();
325        assert_eq!(daemon.client_settings, ClientSettings::default());
326        assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
327    }
328
329    async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
330    where
331        T: NixSerialize + Send,
332    {
333        let mut result: Vec<u8> = Vec::new();
334        let mut w = NixWriter::builder()
335            .set_version(protocol_version)
336            .build(&mut result);
337        w.write_value(req).await.unwrap();
338        w.flush().await.unwrap();
339        result
340    }
341
342    async fn respond<T>(
343        resp: &Result<T, std::io::Error>,
344        protocol_version: ProtocolVersion,
345    ) -> Vec<u8>
346    where
347        T: NixSerialize + Send,
348    {
349        let mut result: Vec<u8> = Vec::new();
350        let mut w = NixWriter::builder()
351            .set_version(protocol_version)
352            .build(&mut result);
353        match resp {
354            Ok(value) => {
355                w.write_value(&STDERR_LAST).await.unwrap();
356                w.write_value(value).await.unwrap();
357            }
358            Err(e) => {
359                w.write_value(&STDERR_ERROR).await.unwrap();
360                w.write_value(&NixError::new(format!("{:?}", e)))
361                    .await
362                    .unwrap();
363            }
364        }
365        w.flush().await.unwrap();
366        result
367    }
368
369    #[tokio::test]
370    async fn test_handle_is_valid_path_ok() {
371        let version = ProtocolVersion::from_parts(1, 37);
372        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
373        let mut mock = MockNixDaemonIO::new();
374        let (reader, writer) = split(io);
375        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
376            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
377        )
378        .unwrap();
379        mock.expect_is_valid_path()
380            .with(predicate::eq(path.clone()))
381            .times(1)
382            .returning(|_| Box::pin(async { Ok(true) }));
383
384        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
385        handle.read(&serialize(&path, version).await);
386        handle.write(&respond(&Ok(true), version).await);
387        drop(handle);
388
389        let mut daemon = NixDaemon::new(
390            Arc::new(mock),
391            version,
392            ClientSettings::default(),
393            NixReader::new(reader),
394            NixWriter::new(writer),
395        );
396        assert_eq!(
397            ErrorKind::UnexpectedEof,
398            daemon
399                .handle_client()
400                .await
401                .expect_err("Expecting eof")
402                .kind()
403        );
404    }
405
406    #[tokio::test]
407    async fn test_handle_is_valid_path_err() {
408        let version = ProtocolVersion::from_parts(1, 37);
409        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
410        let mut mock = MockNixDaemonIO::new();
411        let (reader, writer) = split(io);
412        let path: StorePath<String> = StorePath::<String>::from_absolute_path(
413            "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
414        )
415        .unwrap();
416        mock.expect_is_valid_path()
417            .with(predicate::eq(path.clone()))
418            .times(1)
419            .returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));
420
421        handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
422        handle.read(&serialize(&path, version).await);
423        handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
424        drop(handle);
425
426        let mut daemon = NixDaemon::new(
427            Arc::new(mock),
428            version,
429            ClientSettings::default(),
430            NixReader::new(reader),
431            NixWriter::new(writer),
432        );
433        assert_eq!(
434            ErrorKind::UnexpectedEof,
435            daemon
436                .handle_client()
437                .await
438                .expect_err("Expecting eof")
439                .kind()
440        );
441    }
442}