nix_compat/wire/ser/writer.rs

1use std::fmt::{self, Write as _};
2use std::future::poll_fn;
3use std::io::{self, Cursor};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7use bytes::{Buf, BufMut, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncWrite, AsyncWriteExt};
10
11use crate::wire::{padding_len, ProtocolVersion, EMPTY_BYTES};
12
13use super::{Error, NixWrite};
14
15pub struct NixWriterBuilder {
16    buf: Option<BytesMut>,
17    reserved_buf_size: usize,
18    max_buf_size: usize,
19    version: ProtocolVersion,
20}
21
22impl Default for NixWriterBuilder {
23    fn default() -> Self {
24        Self {
25            buf: Default::default(),
26            reserved_buf_size: 8192,
27            max_buf_size: 8192,
28            version: Default::default(),
29        }
30    }
31}
32
33impl NixWriterBuilder {
34    pub fn set_buffer(mut self, buf: BytesMut) -> Self {
35        self.buf = Some(buf);
36        self
37    }
38
39    pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
40        self.reserved_buf_size = size;
41        self
42    }
43
44    pub fn set_max_buf_size(mut self, size: usize) -> Self {
45        self.max_buf_size = size;
46        self
47    }
48
49    pub fn set_version(mut self, version: ProtocolVersion) -> Self {
50        self.version = version;
51        self
52    }
53
54    pub fn build<W>(self, writer: W) -> NixWriter<W> {
55        let buf = self
56            .buf
57            .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size));
58        NixWriter {
59            buf,
60            inner: writer,
61            reserved_buf_size: self.reserved_buf_size,
62            max_buf_size: self.max_buf_size,
63            version: self.version,
64        }
65    }
66}
67
68pin_project! {
69    pub struct NixWriter<W> {
70        #[pin]
71        inner: W,
72        buf: BytesMut,
73        reserved_buf_size: usize,
74        max_buf_size: usize,
75        version: ProtocolVersion,
76    }
77}
78
79impl NixWriter<Cursor<Vec<u8>>> {
80    pub fn builder() -> NixWriterBuilder {
81        NixWriterBuilder::default()
82    }
83}
84
85impl<W> NixWriter<W>
86where
87    W: AsyncWriteExt,
88{
89    pub fn new(writer: W) -> NixWriter<W> {
90        NixWriter::builder().build(writer)
91    }
92
93    pub fn buffer(&self) -> &[u8] {
94        &self.buf[..]
95    }
96
97    pub fn set_version(&mut self, version: ProtocolVersion) {
98        self.version = version;
99    }
100
101    /// Remaining capacity in internal buffer
102    pub fn remaining_mut(&self) -> usize {
103        self.buf.capacity() - self.buf.len()
104    }
105
106    fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
107        let mut this = self.project();
108        while !this.buf.is_empty() {
109            let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?;
110            if n == 0 {
111                return Poll::Ready(Err(io::Error::new(
112                    io::ErrorKind::WriteZero,
113                    "failed to write the buffer",
114                )));
115            }
116            this.buf.advance(n);
117        }
118        Poll::Ready(Ok(()))
119    }
120}
121
122impl<W> NixWriter<W>
123where
124    W: AsyncWriteExt + Unpin,
125{
126    async fn flush_buf(&mut self) -> Result<(), io::Error> {
127        let mut s = Pin::new(self);
128        poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await
129    }
130}
131
132impl<W> AsyncWrite for NixWriter<W>
133where
134    W: AsyncWrite,
135{
136    fn poll_write(
137        mut self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &[u8],
140    ) -> Poll<Result<usize, io::Error>> {
141        // Flush
142        if self.remaining_mut() < buf.len() {
143            ready!(self.as_mut().poll_flush_buf(cx))?;
144        }
145        let this = self.project();
146        if buf.len() > this.buf.capacity() {
147            this.inner.poll_write(cx, buf)
148        } else {
149            this.buf.put_slice(buf);
150            Poll::Ready(Ok(buf.len()))
151        }
152    }
153
154    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
155        ready!(self.as_mut().poll_flush_buf(cx))?;
156        self.project().inner.poll_flush(cx)
157    }
158
159    fn poll_shutdown(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162    ) -> Poll<Result<(), io::Error>> {
163        ready!(self.as_mut().poll_flush_buf(cx))?;
164        self.project().inner.poll_shutdown(cx)
165    }
166}
167
168impl<W> NixWrite for NixWriter<W>
169where
170    W: AsyncWrite + Send + Unpin,
171{
172    type Error = io::Error;
173
174    fn version(&self) -> ProtocolVersion {
175        self.version
176    }
177
178    async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> {
179        let mut buf = [0u8; 8];
180        BufMut::put_u64_le(&mut &mut buf[..], value);
181        self.write_all(&buf).await
182    }
183
184    async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
185        let padding = padding_len(buf.len() as u64) as usize;
186        self.write_value(&buf.len()).await?;
187        self.write_all(buf).await?;
188        if padding > 0 {
189            self.write_all(&EMPTY_BYTES[..padding]).await
190        } else {
191            Ok(())
192        }
193    }
194
195    async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error>
196    where
197        D: fmt::Display + Send,
198        Self: Sized,
199    {
200        // Ensure that buffer has space for at least reserved_buf_size bytes
201        if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() {
202            self.flush_buf().await?;
203        }
204        let offset = self.buf.len();
205        self.buf.put_u64_le(0);
206        if let Err(err) = write!(self.buf, "{}", msg) {
207            self.buf.truncate(offset);
208            return Err(Self::Error::unsupported_data(err));
209        }
210        let len = self.buf.len() - offset - 8;
211        BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64);
212        let padding = padding_len(len as u64) as usize;
213        self.write_all(&EMPTY_BYTES[..padding]).await
214    }
215}
216
#[cfg(test)]
mod test {
    use std::time::Duration;

    use hex_literal::hex;
    use rstest::rstest;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::io::Builder;

    use crate::wire::ser::NixWrite;

    use super::NixWriter;

    /// Mock builder expecting `expected` as consecutive `chunk_size`-byte
    /// writes, with a zero-length wait after each one.
    fn chunked_mock(expected: &[u8], chunk_size: usize) -> Builder {
        let mut mock = Builder::new();
        for piece in expected.chunks(chunk_size) {
            mock.write(piece).wait(Duration::ZERO);
        }
        mock
    }

    #[rstest]
    #[case(1, &hex!("0100 0000 0000 0000"))]
    #[case::evil(666, &hex!("9A02 0000 0000 0000"))]
    #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))]
    #[tokio::test]
    async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) {
        let mut writer = NixWriter::new(Builder::new().write(buf).build());

        writer.write_number(number).await.unwrap();
        // Still buffered until an explicit flush.
        assert_eq!(writer.buffer(), buf);
        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }

    #[rstest]
    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_write_slice(
        #[case] value: &[u8],
        #[case] buf: &[u8],
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize,
    ) {
        let mock = chunked_mock(buf, chunks_size).build();
        let mut writer = NixWriter::builder().set_max_buf_size(buf_size).build(mock);

        writer.write_slice(value).await.unwrap();
        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }

    #[rstest]
    #[case::empty("", &hex!("0000 0000 0000 0000"))]
    #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_write_display(
        #[case] value: &str,
        #[case] buf: &[u8],
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
    ) {
        let mock = chunked_mock(buf, chunks_size).build();
        let mut writer = NixWriter::builder().build(mock);

        writer.write_display(value).await.unwrap();
        // The whole encoded string sits in the buffer until flushed.
        assert_eq!(writer.buffer(), buf);
        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }
}