nix_compat/wire/de/
mock.rs

1use std::collections::VecDeque;
2use std::fmt;
3use std::io;
4use std::thread;
5
6use bytes::Bytes;
7use thiserror::Error;
8
9use crate::wire::ProtocolVersion;
10
11use super::NixRead;
12
13#[derive(Debug, Error, PartialEq, Eq, Clone)]
14pub enum Error {
15    #[error("custom error '{0}'")]
16    Custom(String),
17    #[error("invalid data '{0}'")]
18    InvalidData(String),
19    #[error("missing data '{0}'")]
20    MissingData(String),
21    #[error("IO error {0} '{1}'")]
22    IO(io::ErrorKind, String),
23    #[error("wrong read: expected {0} got {1}")]
24    WrongRead(OperationType, OperationType),
25}
26
27impl Error {
28    pub fn expected_read_number() -> Error {
29        Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes)
30    }
31
32    pub fn expected_read_bytes() -> Error {
33        Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber)
34    }
35}
36
37impl super::Error for Error {
38    fn custom<T: fmt::Display>(msg: T) -> Self {
39        Self::Custom(msg.to_string())
40    }
41
42    fn io_error(err: std::io::Error) -> Self {
43        Self::IO(err.kind(), err.to_string())
44    }
45
46    fn invalid_data<T: fmt::Display>(msg: T) -> Self {
47        Self::InvalidData(msg.to_string())
48    }
49
50    fn missing_data<T: fmt::Display>(msg: T) -> Self {
51        Self::MissingData(msg.to_string())
52    }
53}
54
/// Which kind of read was requested or scripted; carried by
/// `Error::WrongRead` to describe a mismatch.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OperationType {
    /// A number read (`read_number`).
    ReadNumber,
    /// A bytes read (`read_bytes`).
    ReadBytes,
}
60
61impl fmt::Display for OperationType {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        match self {
64            Self::ReadNumber => write!(f, "read_number"),
65            Self::ReadBytes => write!(f, "read_bytess"),
66        }
67    }
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
71enum Operation {
72    ReadNumber(Result<u64, Error>),
73    ReadBytes(Result<Bytes, Error>),
74}
75
76impl From<Operation> for OperationType {
77    fn from(value: Operation) -> Self {
78        match value {
79            Operation::ReadNumber(_) => OperationType::ReadNumber,
80            Operation::ReadBytes(_) => OperationType::ReadBytes,
81        }
82    }
83}
84
85pub struct Builder {
86    version: ProtocolVersion,
87    ops: VecDeque<Operation>,
88}
89
90impl Builder {
91    pub fn new() -> Builder {
92        Builder {
93            version: Default::default(),
94            ops: VecDeque::new(),
95        }
96    }
97
98    pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self {
99        self.version = version.into();
100        self
101    }
102
103    pub fn read_number(&mut self, value: u64) -> &mut Self {
104        self.ops.push_back(Operation::ReadNumber(Ok(value)));
105        self
106    }
107
108    pub fn read_number_error(&mut self, err: Error) -> &mut Self {
109        self.ops.push_back(Operation::ReadNumber(Err(err)));
110        self
111    }
112
113    pub fn read_bytes(&mut self, value: Bytes) -> &mut Self {
114        self.ops.push_back(Operation::ReadBytes(Ok(value)));
115        self
116    }
117
118    pub fn read_slice(&mut self, data: &[u8]) -> &mut Self {
119        let value = Bytes::copy_from_slice(data);
120        self.ops.push_back(Operation::ReadBytes(Ok(value)));
121        self
122    }
123
124    pub fn read_bytes_error(&mut self, err: Error) -> &mut Self {
125        self.ops.push_back(Operation::ReadBytes(Err(err)));
126        self
127    }
128
129    pub fn build(&mut self) -> Mock {
130        Mock {
131            version: self.version,
132            ops: self.ops.clone(),
133        }
134    }
135}
136
137impl Default for Builder {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143pub struct Mock {
144    version: ProtocolVersion,
145    ops: VecDeque<Operation>,
146}
147
148impl NixRead for Mock {
149    type Error = Error;
150
151    fn version(&self) -> ProtocolVersion {
152        self.version
153    }
154
155    async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> {
156        match self.ops.pop_front() {
157            Some(Operation::ReadNumber(ret)) => ret.map(Some),
158            Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()),
159            None => Ok(None),
160        }
161    }
162
163    async fn try_read_bytes_limited(
164        &mut self,
165        _limit: std::ops::RangeInclusive<usize>,
166    ) -> Result<Option<Bytes>, Self::Error> {
167        match self.ops.pop_front() {
168            Some(Operation::ReadBytes(ret)) => ret.map(Some),
169            Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()),
170            None => Ok(None),
171        }
172    }
173}
174
175impl Drop for Mock {
176    fn drop(&mut self) {
177        // No need to panic again
178        if thread::panicking() {
179            return;
180        }
181        if let Some(op) = self.ops.front() {
182            panic!("reader dropped with {op:?} operation still unread")
183        }
184    }
185}
186
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use hex_literal::hex;

    use crate::wire::de::NixRead;

    use super::{Builder, Error};

    /// Slices queued with `read_slice` (including an empty one) are replayed
    /// in order after a number; the exhausted mock then reports `None`.
    #[tokio::test]
    async fn read_slice() {
        let mut mock = Builder::new()
            .read_number(10)
            .read_slice(&[])
            .read_slice(&hex!("0000 1234 5678 9ABC DEFF"))
            .build();
        assert_eq!(10, mock.read_number().await.unwrap());
        assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]);
        assert_eq!(
            &hex!("0000 1234 5678 9ABC DEFF"),
            &mock.read_bytes().await.unwrap()[..]
        );
        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    /// Same sequence as `read_slice`, but queued as owned `Bytes`.
    #[tokio::test]
    async fn read_bytes() {
        let mut mock = Builder::new()
            .read_number(10)
            .read_bytes(Bytes::from_static(&[]))
            .read_bytes(Bytes::from_static(&hex!("0000 1234 5678 9ABC DEFF")))
            .build();
        assert_eq!(10, mock.read_number().await.unwrap());
        assert_eq!(&[] as &[u8], &mock.read_bytes().await.unwrap()[..]);
        assert_eq!(
            &hex!("0000 1234 5678 9ABC DEFF"),
            &mock.read_bytes().await.unwrap()[..]
        );
        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    #[tokio::test]
    async fn read_number() {
        let mut mock = Builder::new().read_number(10).build();
        assert_eq!(10, mock.read_number().await.unwrap());
        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    /// Asking for bytes while a number is scripted next reports the mismatch.
    #[tokio::test]
    async fn expect_number() {
        let mut mock = Builder::new().read_number(10).build();
        assert_eq!(
            Error::expected_read_number(),
            mock.read_bytes().await.unwrap_err()
        );
    }

    /// Asking for a number while bytes are scripted next reports the mismatch.
    #[tokio::test]
    async fn expect_bytes() {
        let mut mock = Builder::new().read_slice(&[]).build();
        assert_eq!(
            Error::expected_read_bytes(),
            mock.read_number().await.unwrap_err()
        );
    }

    /// An error scripted with `read_number_error` is returned unchanged by the
    /// next number read, which consumes the operation.
    #[tokio::test]
    async fn read_number_error() {
        let mut mock = Builder::new()
            .read_number_error(Error::MissingData("no number".to_string()))
            .build();
        assert_eq!(
            Error::MissingData("no number".to_string()),
            mock.try_read_number().await.unwrap_err()
        );
    }

    /// An error scripted with `read_bytes_error` is returned unchanged by the
    /// next bytes read, which consumes the operation.
    #[tokio::test]
    async fn read_bytes_error() {
        let mut mock = Builder::new()
            .read_bytes_error(Error::Custom("no bytes".to_string()))
            .build();
        assert_eq!(
            Error::Custom("no bytes".to_string()),
            mock.try_read_bytes_limited(0..=usize::MAX)
                .await
                .unwrap_err()
        );
    }

    /// Dropping a mock with unread operations must panic with the leftover
    /// report from `Drop`. Pinning the message keeps an unrelated panic from
    /// satisfying the test.
    #[test]
    #[should_panic(expected = "operation still unread")]
    fn operations_left() {
        let _ = Builder::new().read_number(10).build();
    }
}