1use std::{
2    cmp::min,
3    collections::BTreeMap,
4    io::{Error, ErrorKind},
5};
6
7use nix_compat_derive::{NixDeserialize, NixSerialize};
8use num_enum::{IntoPrimitive, TryFromPrimitive};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
11use crate::{log::VerbosityLevel, wire};
12
13use crate::wire::ProtocolVersion;
14
15pub(crate) static WORKER_MAGIC_1: u64 = 0x6e697863; pub(crate) static WORKER_MAGIC_2: u64 = 0x6478696f; pub static STDERR_LAST: u64 = 0x616c7473; pub(crate) static STDERR_ERROR: u64 = 0x63787470; pub(crate) static STDERR_READ: u64 = 0x64617461; static PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::from_parts(1, 37);
43
/// Size limit of 1 KiB (1024).
///
/// NOTE(review): judging by the name, this caps the byte length of a client setting
/// name/value. Nothing in this chunk enforces it, so confirm with its users.
pub static MAX_SETTING_SIZE: usize = 1 << 10;
49
50#[derive(
61    Clone, Debug, PartialEq, TryFromPrimitive, IntoPrimitive, NixDeserialize, NixSerialize,
62)]
63#[nix(try_from = "u64", into = "u64")]
64#[repr(u64)]
65pub enum Operation {
66    IsValidPath = 1,
67    HasSubstitutes = 3,
68    QueryPathHash = 4,   QueryReferences = 5, QueryReferrers = 6,
71    AddToStore = 7,
72    AddTextToStore = 8, BuildPaths = 9,
74    EnsurePath = 10,
75    AddTempRoot = 11,
76    AddIndirectRoot = 12,
77    SyncWithGC = 13,
78    FindRoots = 14,
79    ExportPath = 16,   QueryDeriver = 18, SetOptions = 19,
82    CollectGarbage = 20,
83    QuerySubstitutablePathInfo = 21,
84    QueryDerivationOutputs = 22, QueryAllValidPaths = 23,
86    QueryFailedPaths = 24,
87    ClearFailedPaths = 25,
88    QueryPathInfo = 26,
89    ImportPaths = 27,                QueryDerivationOutputNames = 28, QueryPathFromHashPart = 29,
92    QuerySubstitutablePathInfos = 30,
93    QueryValidPaths = 31,
94    QuerySubstitutablePaths = 32,
95    QueryValidDerivers = 33,
96    OptimiseStore = 34,
97    VerifyStore = 35,
98    BuildDerivation = 36,
99    AddSignatures = 37,
100    NarFromPath = 38,
101    AddToStoreNar = 39,
102    QueryMissing = 40,
103    QueryDerivationOutputMap = 41,
104    RegisterDrvOutput = 42,
105    QueryRealisation = 43,
106    AddMultipleToStore = 44,
107    AddBuildLog = 45,
108    BuildPathsWithResults = 46,
109    AddPermRoot = 47,
110}
111
/// Settings supplied by a Nix client, (de)serialized with the derived
/// `NixDeserialize` / `NixSerialize` implementations.
///
/// NOTE(review): presumably the client sends these with `Operation::SetOptions`.
/// Confirm against the operation handler.
///
/// NOTE(review): the derived (de)serializers very likely process fields in
/// declaration order, which would make this order part of the wire format. Do not
/// reorder, insert or remove fields without checking the derive macro.
#[derive(Debug, PartialEq, NixDeserialize, NixSerialize, Default)]
pub struct ClientSettings {
    pub keep_failed: bool,
    pub keep_going: bool,
    pub try_fallback: bool,
    /// Verbosity level requested by the client.
    pub verbosity: VerbosityLevel,
    pub max_build_jobs: u64,
    pub max_silent_time: u64,
    pub use_build_hook: bool,
    pub verbose_build: u64,
    pub log_type: u64,
    pub print_build_trace: u64,
    pub build_cores: u64,
    pub use_substitutes: bool,

    /// Free-form key/value setting overrides sent by the client.
    ///
    /// Gated by `#[nix(version = "12..")]`, so this field is only part of the
    /// wire data for sufficiently new protocol versions. Presumably this means
    /// minor version 12 and up; confirm the exact semantics in the derive macro.
    #[nix(version = "12..")]
    pub overrides: BTreeMap<String, String>,
}
141
142pub async fn server_handshake_client<'a, RW: 'a>(
160    mut conn: &'a mut RW,
161    nix_version: &str,
162    trusted: Trust,
163) -> std::io::Result<ProtocolVersion>
164where
165    &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
166{
167    let worker_magic_1 = conn.read_u64_le().await?;
168    if worker_magic_1 != WORKER_MAGIC_1 {
169        Err(std::io::Error::new(
170            ErrorKind::InvalidData,
171            format!("Incorrect worker magic number received: {worker_magic_1}"),
172        ))
173    } else {
174        conn.write_u64_le(WORKER_MAGIC_2).await?;
175        conn.write_u64_le(PROTOCOL_VERSION.into()).await?;
176        conn.flush().await?;
177        let client_version = conn.read_u64_le().await?;
178        let client_version: ProtocolVersion = client_version
180            .try_into()
181            .map_err(|e| Error::new(ErrorKind::Unsupported, e))?;
182        if client_version < ProtocolVersion::from_parts(1, 10) {
183            return Err(Error::new(
184                ErrorKind::Unsupported,
185                format!("The nix client version {client_version} is too old"),
186            ));
187        }
188        let picked_version = min(PROTOCOL_VERSION, client_version);
189        if picked_version.minor() >= 14 {
190            let read_affinity = conn.read_u64_le().await?;
192            if read_affinity != 0 {
193                let _cpu_affinity = conn.read_u64_le().await?;
194            };
195        }
196        if picked_version.minor() >= 11 {
197            let _reserve_space = conn.read_u64_le().await?;
199        }
200        if picked_version.minor() >= 33 {
201            wire::write_bytes(&mut conn, nix_version).await?;
203            conn.flush().await?;
204        }
205        if picked_version.minor() >= 35 {
206            write_worker_trust_level(&mut conn, trusted).await?;
207        }
208        Ok(picked_version)
209    }
210}
211
212pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
214    let op_number = r.read_u64_le().await?;
215    Operation::try_from(op_number).map_err(|_| {
216        Error::new(
217            ErrorKind::InvalidData,
218            format!("Invalid OP number {op_number}"),
219        )
220    })
221}
222
223pub async fn write_op<W: AsyncWriteExt + Unpin>(w: &mut W, op: Operation) -> std::io::Result<()> {
225    let op: u64 = op.into();
226    w.write_u64(op).await
227}
228
/// Whether the connected client is trusted, as reported to it during the handshake.
///
/// A plain fieldless value, so it is `Copy` and can be passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Trust {
    Trusted,
    NotTrusted,
}
234
235pub async fn write_worker_trust_level<W>(conn: &mut W, t: Trust) -> std::io::Result<()>
242where
243    W: AsyncReadExt + AsyncWriteExt + Unpin,
244{
245    match t {
246        Trust::Trusted => conn.write_u64_le(1).await,
247        Trust::NotTrusted => conn.write_u64_le(2).await,
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_init_handshake() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            // CPU-affinity flag: 0, so no affinity value follows.
            .read(&[0; 8])
            // Reserve-space field.
            .read(&[0; 8])
            // Length prefix of "2.18.2" (6 bytes).
            .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            // "2.18.2", zero-padded to 8 bytes.
            .write(&[50, 46, 49, 56, 46, 50, 0, 0])
            // Trust level: 1 = trusted.
            .write(&[1, 0, 0, 0, 0, 0, 0, 0])
            .build();
        let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
            .await
            .unwrap();

        assert_eq!(picked_version, PROTOCOL_VERSION)
    }

    #[tokio::test]
    async fn test_init_handshake_with_newer_client_should_use_older_version() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[38, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[0; 8])
            .read(&[0; 8])
            .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .write(&[50, 46, 49, 56, 46, 50, 0, 0])
            .write(&[1, 0, 0, 0, 0, 0, 0, 0])
            .build();
        let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
            .await
            .unwrap();

        assert_eq!(picked_version, PROTOCOL_VERSION)
    }

    #[tokio::test]
    async fn test_init_handshake_with_older_client_should_use_older_version() {
        // A 1.24 client is below the 1.33 / 1.35 thresholds, so the server sends
        // neither its version string nor the trust level.
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[24, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[0; 8])
            .read(&[0; 8])
            .build();
        let picked_version = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
            .await
            .unwrap();

        assert_eq!(picked_version, ProtocolVersion::from_parts(1, 24))
    }

    #[tokio::test]
    async fn test_init_handshake_rejects_wrong_magic() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&WORKER_MAGIC_2.to_le_bytes())
            .build();
        let err = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
            .await
            .unwrap_err();

        assert_eq!(err.kind(), ErrorKind::InvalidData)
    }

    #[tokio::test]
    async fn test_init_handshake_rejects_too_old_client() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&WORKER_MAGIC_1.to_le_bytes())
            .write(&WORKER_MAGIC_2.to_le_bytes())
            .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .read(&[9, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
            .build();
        let err = server_handshake_client(&mut test_conn, "2.18.2", Trust::Trusted)
            .await
            .unwrap_err();

        assert_eq!(err.kind(), ErrorKind::Unsupported)
    }

    #[tokio::test]
    async fn test_read_op_decodes_little_endian() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&[19, 0, 0, 0, 0, 0, 0, 0])
            .build();
        let op = read_op(&mut test_conn).await.unwrap();

        assert_eq!(op, Operation::SetOptions)
    }

    #[tokio::test]
    async fn test_read_op_rejects_unknown_opcode() {
        let mut test_conn = tokio_test::io::Builder::new()
            .read(&[2, 0, 0, 0, 0, 0, 0, 0])
            .build();
        let err = read_op(&mut test_conn).await.unwrap_err();

        assert_eq!(err.kind(), ErrorKind::InvalidData)
    }
}