nix_compat/wire/de/
mock.rs1use std::collections::VecDeque;
2use std::fmt;
3use std::io;
4use std::thread;
5
6use bytes::Bytes;
7use thiserror::Error;
8
9use crate::wire::ProtocolVersion;
10
11use super::NixRead;
12
13#[derive(Debug, Error, PartialEq, Eq, Clone)]
14pub enum Error {
15 #[error("custom error '{0}'")]
16 Custom(String),
17 #[error("invalid data '{0}'")]
18 InvalidData(String),
19 #[error("missing data '{0}'")]
20 MissingData(String),
21 #[error("IO error {0} '{1}'")]
22 IO(io::ErrorKind, String),
23 #[error("wrong read: expected {0} got {1}")]
24 WrongRead(OperationType, OperationType),
25}
26
27impl Error {
28 pub fn expected_read_number() -> Error {
29 Error::WrongRead(OperationType::ReadNumber, OperationType::ReadBytes)
30 }
31
32 pub fn expected_read_bytes() -> Error {
33 Error::WrongRead(OperationType::ReadBytes, OperationType::ReadNumber)
34 }
35}
36
37impl super::Error for Error {
38 fn custom<T: fmt::Display>(msg: T) -> Self {
39 Self::Custom(msg.to_string())
40 }
41
42 fn io_error(err: std::io::Error) -> Self {
43 Self::IO(err.kind(), err.to_string())
44 }
45
46 fn invalid_data<T: fmt::Display>(msg: T) -> Self {
47 Self::InvalidData(msg.to_string())
48 }
49
50 fn missing_data<T: fmt::Display>(msg: T) -> Self {
51 Self::MissingData(msg.to_string())
52 }
53}
54
/// The kind of read a caller performs on the mock reader, or that a test has
/// queued for it.
///
/// Used in `Error::WrongRead` to report which read was expected and which
/// one actually happened.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// A number read (`read_number` / `try_read_number`).
    ReadNumber,
    /// A bytes read (`read_bytes` / `try_read_bytes` / `try_read_bytes_limited`).
    ReadBytes,
}

impl fmt::Display for OperationType {
    /// Renders the name of the corresponding reader method.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // These names appear verbatim in `WrongRead` error messages, so they
        // must match the real method names exactly (this used to say
        // "read_bytess").
        let name = match self {
            Self::ReadNumber => "read_number",
            Self::ReadBytes => "read_bytes",
        };
        f.write_str(name)
    }
}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
71enum Operation {
72 ReadNumber(Result<u64, Error>),
73 ReadBytes(Result<Bytes, Error>),
74}
75
76impl From<Operation> for OperationType {
77 fn from(value: Operation) -> Self {
78 match value {
79 Operation::ReadNumber(_) => OperationType::ReadNumber,
80 Operation::ReadBytes(_) => OperationType::ReadBytes,
81 }
82 }
83}
84
85pub struct Builder {
86 version: ProtocolVersion,
87 ops: VecDeque<Operation>,
88}
89
90impl Builder {
91 pub fn new() -> Builder {
92 Builder {
93 version: Default::default(),
94 ops: VecDeque::new(),
95 }
96 }
97
98 pub fn version<V: Into<ProtocolVersion>>(&mut self, version: V) -> &mut Self {
99 self.version = version.into();
100 self
101 }
102
103 pub fn read_number(&mut self, value: u64) -> &mut Self {
104 self.ops.push_back(Operation::ReadNumber(Ok(value)));
105 self
106 }
107
108 pub fn read_number_error(&mut self, err: Error) -> &mut Self {
109 self.ops.push_back(Operation::ReadNumber(Err(err)));
110 self
111 }
112
113 pub fn read_bytes(&mut self, value: Bytes) -> &mut Self {
114 self.ops.push_back(Operation::ReadBytes(Ok(value)));
115 self
116 }
117
118 pub fn read_slice(&mut self, data: &[u8]) -> &mut Self {
119 let value = Bytes::copy_from_slice(data);
120 self.ops.push_back(Operation::ReadBytes(Ok(value)));
121 self
122 }
123
124 pub fn read_bytes_error(&mut self, err: Error) -> &mut Self {
125 self.ops.push_back(Operation::ReadBytes(Err(err)));
126 self
127 }
128
129 pub fn build(&mut self) -> Mock {
130 Mock {
131 version: self.version,
132 ops: self.ops.clone(),
133 }
134 }
135}
136
137impl Default for Builder {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143pub struct Mock {
144 version: ProtocolVersion,
145 ops: VecDeque<Operation>,
146}
147
148impl NixRead for Mock {
149 type Error = Error;
150
151 fn version(&self) -> ProtocolVersion {
152 self.version
153 }
154
155 async fn try_read_number(&mut self) -> Result<Option<u64>, Self::Error> {
156 match self.ops.pop_front() {
157 Some(Operation::ReadNumber(ret)) => ret.map(Some),
158 Some(Operation::ReadBytes(_)) => Err(Error::expected_read_bytes()),
159 None => Ok(None),
160 }
161 }
162
163 async fn try_read_bytes_limited(
164 &mut self,
165 _limit: std::ops::RangeInclusive<usize>,
166 ) -> Result<Option<Bytes>, Self::Error> {
167 match self.ops.pop_front() {
168 Some(Operation::ReadBytes(ret)) => ret.map(Some),
169 Some(Operation::ReadNumber(_)) => Err(Error::expected_read_number()),
170 None => Ok(None),
171 }
172 }
173}
174
175impl Drop for Mock {
176 fn drop(&mut self) {
177 if thread::panicking() {
179 return;
180 }
181 if let Some(op) = self.ops.front() {
182 panic!("reader dropped with {op:?} operation still unread")
183 }
184 }
185}
186
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use hex_literal::hex;

    use crate::wire::de::NixRead;

    use super::{Builder, Error};

    /// Arbitrary ten-byte payload shared by the byte-reading tests.
    const PAYLOAD: [u8; 10] = hex!("0000 1234 5678 9ABC DEFF");

    #[tokio::test]
    async fn read_slice() {
        let mut mock = Builder::new()
            .read_number(10)
            .read_slice(&[])
            .read_slice(&PAYLOAD)
            .build();

        let number = mock.read_number().await.unwrap();
        assert_eq!(10, number);

        let empty = mock.read_bytes().await.unwrap();
        assert_eq!(&[] as &[u8], &empty[..]);

        let data = mock.read_bytes().await.unwrap();
        assert_eq!(&PAYLOAD, &data[..]);

        // Once the script is exhausted, the try_* variants report end of input.
        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    #[tokio::test]
    async fn read_bytes() {
        let mut mock = Builder::new()
            .read_number(10)
            .read_bytes(Bytes::from_static(&[]))
            .read_bytes(Bytes::from_static(&PAYLOAD))
            .build();

        let number = mock.read_number().await.unwrap();
        assert_eq!(10, number);

        let empty = mock.read_bytes().await.unwrap();
        assert_eq!(&[] as &[u8], &empty[..]);

        let data = mock.read_bytes().await.unwrap();
        assert_eq!(&PAYLOAD, &data[..]);

        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    #[tokio::test]
    async fn read_number() {
        let mut mock = Builder::new().read_number(10).build();

        let number = mock.read_number().await.unwrap();
        assert_eq!(10, number);

        assert_eq!(None, mock.try_read_number().await.unwrap());
        assert_eq!(None, mock.try_read_bytes().await.unwrap());
    }

    #[tokio::test]
    async fn expect_number() {
        // A number is queued, but the caller asks for bytes.
        let mut mock = Builder::new().read_number(10).build();
        let err = mock.read_bytes().await.unwrap_err();
        assert_eq!(Error::expected_read_number(), err);
    }

    #[tokio::test]
    async fn expect_bytes() {
        // Bytes are queued, but the caller asks for a number.
        let mut mock = Builder::new().read_slice(&[]).build();
        let err = mock.read_number().await.unwrap_err();
        assert_eq!(Error::expected_read_bytes(), err);
    }

    #[test]
    #[should_panic]
    fn operations_left() {
        // Dropping a mock that still has an unread operation must panic.
        drop(Builder::new().read_number(10).build());
    }
}