1use std::{future::Future, ops::DerefMut, sync::Arc};
2
3use bytes::Bytes;
4use tokio::{
5 io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
6 sync::Mutex,
7};
8use tracing::{debug, warn};
9
10use super::{
11 framing::{NixFramedReader, StderrReadFramedReader},
12 types::{AddToStoreNarRequest, QueryValidPaths},
13 worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
14 NixDaemonIO,
15};
16
17use crate::{
18 store_path::StorePath,
19 wire::{
20 de::{NixRead, NixReader},
21 ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
22 ProtocolVersion,
23 },
24};
25
26use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};
27
28#[allow(dead_code)]
39pub struct NixDaemon<IO, R, W> {
40 io: Arc<IO>,
41 protocol_version: ProtocolVersion,
42 client_settings: ClientSettings,
43 reader: NixReader<R>,
44 writer: Arc<Mutex<NixWriter<W>>>,
45}
46
47impl<IO, R, W> NixDaemon<IO, R, W>
48where
49 IO: NixDaemonIO + Sync + Send,
50{
51 pub fn new(
52 io: Arc<IO>,
53 protocol_version: ProtocolVersion,
54 client_settings: ClientSettings,
55 reader: NixReader<R>,
56 writer: NixWriter<W>,
57 ) -> Self {
58 Self {
59 io,
60 protocol_version,
61 client_settings,
62 reader,
63 writer: Arc::new(Mutex::new(writer)),
64 }
65 }
66}
67
68impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
69where
70 RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
71 IO: NixDaemonIO + Sync + Send,
72{
73 pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
80 where
81 RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
82 {
83 let protocol_version =
84 server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
85
86 connection.write_u64_le(STDERR_LAST).await?;
87 let (reader, writer) = split(connection);
88 let mut reader = NixReader::builder()
89 .set_version(protocol_version)
90 .build(reader);
91 let mut writer = NixWriterBuilder::default()
92 .set_version(protocol_version)
93 .build(writer);
94
95 let operation: Operation = reader.read_value().await?;
97 if operation != Operation::SetOptions {
98 return Err(std::io::Error::other(
99 "Expected SetOptions operation, but got {operation}",
100 ));
101 }
102 let client_settings: ClientSettings = reader.read_value().await?;
103 writer.write_number(STDERR_LAST).await?;
104 writer.flush().await?;
105
106 Ok(Self::new(
107 io,
108 protocol_version,
109 client_settings,
110 reader,
111 writer,
112 ))
113 }
114
115 pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
117 let io = self.io.clone();
118 loop {
119 let op_code = self.reader.read_number().await?;
120 let op = TryInto::<Operation>::try_into(op_code);
121 debug!(?op, "Received operation");
122 match op {
123 Ok(operation) => match operation {
125 Operation::IsValidPath => {
126 let path: StorePath<String> = self.reader.read_value().await?;
127 Self::handle(&self.writer, io.is_valid_path(&path)).await?
128 }
129 Operation::SetOptions => {
134 self.client_settings = self.reader.read_value().await?;
135 Self::handle(&self.writer, async { Ok(()) }).await?
136 }
137 Operation::QueryPathInfo => {
138 let path: StorePath<String> = self.reader.read_value().await?;
139 Self::handle(&self.writer, io.query_path_info(&path)).await?
140 }
141 Operation::QueryPathFromHashPart => {
142 let hash: Bytes = self.reader.read_value().await?;
143 Self::handle(&self.writer, io.query_path_from_hash_part(&hash)).await?
144 }
145 Operation::QueryValidPaths => {
146 let query: QueryValidPaths = self.reader.read_value().await?;
147 Self::handle(&self.writer, io.query_valid_paths(&query)).await?
148 }
149 Operation::QueryValidDerivers => {
150 let path: StorePath<String> = self.reader.read_value().await?;
151 Self::handle(&self.writer, io.query_valid_derivers(&path)).await?
152 }
153 Operation::QueryReferrers | Operation::QueryRealisation => {
160 let _: String = self.reader.read_value().await?;
161 Self::handle(&self.writer, async move {
162 warn!(
163 ?operation,
164 "This operation is not implemented. Returning empty result..."
165 );
166 Ok(Vec::<StorePath<String>>::new())
167 })
168 .await?
169 }
170 Operation::AddToStoreNar => {
171 let request: AddToStoreNarRequest = self.reader.read_value().await?;
172 let minor_version = self.protocol_version.minor();
173 match minor_version {
174 ..21 => {
175 Self::handle(
178 &self.writer,
179 self.io.add_to_store_nar(request, &mut self.reader),
180 )
181 .await?
182 }
183 21..23 => {
184 Self::handle(&self.writer, async {
186 let mut writer = self.writer.lock().await;
187 let mut reader = StderrReadFramedReader::new(
188 &mut self.reader,
189 writer.deref_mut(),
190 );
191 self.io.add_to_store_nar(request, &mut reader).await
192 })
194 .await?
195 }
196 23.. => {
197 let mut framed = NixFramedReader::new(&mut self.reader);
199
200 Self::handle(&self.writer, async {
201 self.io.add_to_store_nar(request, &mut framed).await
202 })
203 .await?;
204
205 if !framed.is_eof() {
208 return Err(std::io::Error::new(
209 std::io::ErrorKind::InvalidData,
210 "payload was not fully consumed",
211 ));
212 }
213 }
214 }
215 }
216 _ => {
217 return Err(std::io::Error::other(format!(
218 "Operation {operation:?} is not implemented"
219 )));
220 }
221 },
222 _ => {
223 return Err(std::io::Error::other(format!(
224 "Unknown operation code received: {op_code}"
225 )));
226 }
227 }
228 }
229 }
230
231 async fn handle<T>(
241 writer: &Arc<Mutex<NixWriter<WriteHalf<RW>>>>,
242 future: impl Future<Output = std::io::Result<T>>,
243 ) -> Result<(), std::io::Error>
244 where
245 T: NixSerialize + Send,
246 {
247 let result = future.await;
248 let mut writer = writer.lock().await;
249
250 match result {
251 Ok(r) => {
252 writer.write_number(STDERR_LAST).await?;
255 writer.write_value(&r).await?;
256 writer.flush().await
257 }
258 Err(e) => {
259 debug!(err = ?e, "IO error");
260 writer.write_number(STDERR_ERROR).await?;
261 writer.write_value(&NixError::new(format!("{e:?}"))).await?;
262 writer.flush().await
263 }
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use std::{io::ErrorKind, sync::Arc};

    use mockall::predicate;
    use tokio::io::AsyncWriteExt;

    use crate::{
        nix_daemon::MockNixDaemonIO,
        wire::ProtocolVersion,
        worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
    };

    /// Absolute store path used as the argument of the `IsValidPath` tests.
    const HELLO_PATH: &str = "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1";

    /// Parses [`HELLO_PATH`] into a `StorePath`.
    fn hello_path() -> StorePath<String> {
        StorePath::<String>::from_absolute_path(HELLO_PATH.as_bytes()).unwrap()
    }

    /// Little-endian wire encoding of an operation code as sent by the client.
    fn op_code_bytes(op: Operation) -> [u8; 8] {
        Into::<u64>::into(op).to_le_bytes()
    }

    /// Drives `handle_client` over `io` until the scripted conversation runs out and
    /// asserts that the loop ended with an EOF error.
    async fn assert_serves_until_eof(
        mock: MockNixDaemonIO,
        version: ProtocolVersion,
        io: tokio_test::io::Mock,
    ) {
        let (reader, writer) = split(io);
        let mut daemon = NixDaemon::new(
            Arc::new(mock),
            version,
            ClientSettings::default(),
            NixReader::new(reader),
            NixWriter::new(writer),
        );
        let err = daemon.handle_client().await.expect_err("Expecting eof");
        assert_eq!(ErrorKind::UnexpectedEof, err.kind());
    }

    #[tokio::test]
    async fn test_daemon_initialization() {
        // Payload of the initial SetOptions operation.
        let settings = {
            let mut buf: Vec<u8> = Vec::new();
            let mut settings_writer = NixWriter::new(&mut buf);
            settings_writer
                .write_value(&ClientSettings::default())
                .await
                .unwrap();
            settings_writer.flush().await.unwrap();
            buf
        };

        let test_conn = tokio_test::io::Builder::new()
            // Magic numbers, then the daemon's protocol version (1.37) and the
            // client's (1.35).
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            .write(&[37, 1, 0, 0, 0, 0, 0, 0])
            .read(&[35, 1, 0, 0, 0, 0, 0, 0])
            // Two further zero-valued u64 fields from the client (presumably CPU
            // affinity and reserveSpace; confirm against `server_handshake_client`).
            .read(&[0; 8])
            .read(&[0; 8])
            // The daemon's Nix version string "2.18.2": length, then padded bytes.
            .write(&6u64.to_le_bytes())
            .write(b"2.18.2\0\0")
            // Trust status (trusted).
            .write(&1u64.to_le_bytes())
            // End of the handshake's stderr stream.
            .write(&STDERR_LAST.to_le_bytes())
            // SetOptions with default settings, acknowledged with STDERR_LAST.
            .read(&op_code_bytes(Operation::SetOptions))
            .read(&settings)
            .write(&STDERR_LAST.to_le_bytes())
            .build();

        let daemon = NixDaemon::initialize(Arc::new(MockNixDaemonIO::new()), test_conn)
            .await
            .unwrap();
        assert_eq!(daemon.client_settings, ClientSettings::default());
        assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
    }

    /// Serializes `req` the way a client speaking `protocol_version` would.
    async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
    where
        T: NixSerialize + Send,
    {
        let mut encoded: Vec<u8> = Vec::new();
        let mut w = NixWriter::builder()
            .set_version(protocol_version)
            .build(&mut encoded);
        w.write_value(req).await.unwrap();
        w.flush().await.unwrap();
        encoded
    }

    /// Builds the exact bytes the daemon is expected to send in reply to `resp`.
    async fn respond<T>(
        resp: &Result<T, std::io::Error>,
        protocol_version: ProtocolVersion,
    ) -> Vec<u8>
    where
        T: NixSerialize + Send,
    {
        let mut encoded: Vec<u8> = Vec::new();
        let mut w = NixWriter::builder()
            .set_version(protocol_version)
            .build(&mut encoded);
        if let Ok(value) = resp {
            w.write_value(&STDERR_LAST).await.unwrap();
            w.write_value(value).await.unwrap();
        } else if let Err(e) = resp {
            w.write_value(&STDERR_ERROR).await.unwrap();
            w.write_value(&NixError::new(format!("{e:?}")))
                .await
                .unwrap();
        }
        w.flush().await.unwrap();
        encoded
    }

    #[tokio::test]
    async fn test_handle_is_valid_path_ok() {
        let version = ProtocolVersion::from_parts(1, 37);
        let path = hello_path();

        let mut mock = MockNixDaemonIO::new();
        mock.expect_is_valid_path()
            .with(predicate::eq(path.clone()))
            .times(1)
            .returning(|_| Box::pin(async { Ok(true) }));

        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
        handle.read(&op_code_bytes(Operation::IsValidPath));
        handle.read(&serialize(&path, version).await);
        handle.write(&respond(&Ok(true), version).await);
        drop(handle);

        assert_serves_until_eof(mock, version, io).await;
    }

    #[tokio::test]
    async fn test_handle_is_valid_path_err() {
        let version = ProtocolVersion::from_parts(1, 37);
        let path = hello_path();

        let mut mock = MockNixDaemonIO::new();
        mock.expect_is_valid_path()
            .with(predicate::eq(path.clone()))
            .times(1)
            .returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));

        let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
        handle.read(&op_code_bytes(Operation::IsValidPath));
        handle.read(&serialize(&path, version).await);
        handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
        drop(handle);

        assert_serves_until_eof(mock, version, io).await;
    }
}