1use std::fmt::{self, Write as _};
2use std::future::poll_fn;
3use std::io::{self, Cursor};
4use std::pin::Pin;
5use std::task::{ready, Context, Poll};
6
7use bytes::{Buf, BufMut, BytesMut};
8use pin_project_lite::pin_project;
9use tokio::io::{AsyncWrite, AsyncWriteExt};
10
11use crate::wire::{padding_len, ProtocolVersion, EMPTY_BYTES};
12
13use super::{Error, NixWrite};
14
15pub struct NixWriterBuilder {
16 buf: Option<BytesMut>,
17 reserved_buf_size: usize,
18 max_buf_size: usize,
19 version: ProtocolVersion,
20}
21
22impl Default for NixWriterBuilder {
23 fn default() -> Self {
24 Self {
25 buf: Default::default(),
26 reserved_buf_size: 8192,
27 max_buf_size: 8192,
28 version: Default::default(),
29 }
30 }
31}
32
33impl NixWriterBuilder {
34 pub fn set_buffer(mut self, buf: BytesMut) -> Self {
35 self.buf = Some(buf);
36 self
37 }
38
39 pub fn set_reserved_buf_size(mut self, size: usize) -> Self {
40 self.reserved_buf_size = size;
41 self
42 }
43
44 pub fn set_max_buf_size(mut self, size: usize) -> Self {
45 self.max_buf_size = size;
46 self
47 }
48
49 pub fn set_version(mut self, version: ProtocolVersion) -> Self {
50 self.version = version;
51 self
52 }
53
54 pub fn build<W>(self, writer: W) -> NixWriter<W> {
55 let buf = self
56 .buf
57 .unwrap_or_else(|| BytesMut::with_capacity(self.max_buf_size));
58 NixWriter {
59 buf,
60 inner: writer,
61 reserved_buf_size: self.reserved_buf_size,
62 max_buf_size: self.max_buf_size,
63 version: self.version,
64 }
65 }
66}
67
68pin_project! {
69 pub struct NixWriter<W> {
70 #[pin]
71 inner: W,
72 buf: BytesMut,
73 reserved_buf_size: usize,
74 max_buf_size: usize,
75 version: ProtocolVersion,
76 }
77}
78
79impl NixWriter<Cursor<Vec<u8>>> {
80 pub fn builder() -> NixWriterBuilder {
81 NixWriterBuilder::default()
82 }
83}
84
85impl<W> NixWriter<W>
86where
87 W: AsyncWriteExt,
88{
89 pub fn new(writer: W) -> NixWriter<W> {
90 NixWriter::builder().build(writer)
91 }
92
93 pub fn buffer(&self) -> &[u8] {
94 &self.buf[..]
95 }
96
97 pub fn set_version(&mut self, version: ProtocolVersion) {
98 self.version = version;
99 }
100
101 pub fn remaining_mut(&self) -> usize {
103 self.buf.capacity() - self.buf.len()
104 }
105
106 fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
107 let mut this = self.project();
108 while !this.buf.is_empty() {
109 let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[..]))?;
110 if n == 0 {
111 return Poll::Ready(Err(io::Error::new(
112 io::ErrorKind::WriteZero,
113 "failed to write the buffer",
114 )));
115 }
116 this.buf.advance(n);
117 }
118 Poll::Ready(Ok(()))
119 }
120}
121
122impl<W> NixWriter<W>
123where
124 W: AsyncWriteExt + Unpin,
125{
126 async fn flush_buf(&mut self) -> Result<(), io::Error> {
127 let mut s = Pin::new(self);
128 poll_fn(move |cx| s.as_mut().poll_flush_buf(cx)).await
129 }
130}
131
132impl<W> AsyncWrite for NixWriter<W>
133where
134 W: AsyncWrite,
135{
136 fn poll_write(
137 mut self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &[u8],
140 ) -> Poll<Result<usize, io::Error>> {
141 if self.remaining_mut() < buf.len() {
143 ready!(self.as_mut().poll_flush_buf(cx))?;
144 }
145 let this = self.project();
146 if buf.len() > this.buf.capacity() {
147 this.inner.poll_write(cx, buf)
148 } else {
149 this.buf.put_slice(buf);
150 Poll::Ready(Ok(buf.len()))
151 }
152 }
153
154 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
155 ready!(self.as_mut().poll_flush_buf(cx))?;
156 self.project().inner.poll_flush(cx)
157 }
158
159 fn poll_shutdown(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 ) -> Poll<Result<(), io::Error>> {
163 ready!(self.as_mut().poll_flush_buf(cx))?;
164 self.project().inner.poll_shutdown(cx)
165 }
166}
167
168impl<W> NixWrite for NixWriter<W>
169where
170 W: AsyncWrite + Send + Unpin,
171{
172 type Error = io::Error;
173
174 fn version(&self) -> ProtocolVersion {
175 self.version
176 }
177
178 async fn write_number(&mut self, value: u64) -> Result<(), Self::Error> {
179 let mut buf = [0u8; 8];
180 BufMut::put_u64_le(&mut &mut buf[..], value);
181 self.write_all(&buf).await
182 }
183
184 async fn write_slice(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
185 let padding = padding_len(buf.len() as u64) as usize;
186 self.write_value(&buf.len()).await?;
187 self.write_all(buf).await?;
188 if padding > 0 {
189 self.write_all(&EMPTY_BYTES[..padding]).await
190 } else {
191 Ok(())
192 }
193 }
194
195 async fn write_display<D>(&mut self, msg: D) -> Result<(), Self::Error>
196 where
197 D: fmt::Display + Send,
198 Self: Sized,
199 {
200 if self.remaining_mut() < self.reserved_buf_size && !self.buf.is_empty() {
202 self.flush_buf().await?;
203 }
204 let offset = self.buf.len();
205 self.buf.put_u64_le(0);
206 if let Err(err) = write!(self.buf, "{}", msg) {
207 self.buf.truncate(offset);
208 return Err(Self::Error::unsupported_data(err));
209 }
210 let len = self.buf.len() - offset - 8;
211 BufMut::put_u64_le(&mut &mut self.buf[offset..(offset + 8)], len as u64);
212 let padding = padding_len(len as u64) as usize;
213 self.write_all(&EMPTY_BYTES[..padding]).await
214 }
215}
216
#[cfg(test)]
mod test {
    use std::time::Duration;

    use hex_literal::hex;
    use rstest::rstest;
    use tokio::io::AsyncWriteExt as _;
    use tokio_test::io::Builder;

    use crate::wire::ser::NixWrite;

    use super::NixWriter;

    // A number is staged in the buffer as eight little-endian bytes and
    // reaches the mock on flush.
    #[rstest]
    #[case(1, &hex!("0100 0000 0000 0000"))]
    #[case::evil(666, &hex!("9A02 0000 0000 0000"))]
    #[case::max(u64::MAX, &hex!("FFFF FFFF FFFF FFFF"))]
    #[tokio::test]
    async fn test_write_number(#[case] number: u64, #[case] buf: &[u8]) {
        let mut writer = NixWriter::new(Builder::new().write(buf).build());

        writer.write_number(number).await.unwrap();
        assert_eq!(writer.buffer(), buf);

        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }

    // Byte slices are framed correctly for every combination of mock chunk
    // size and writer buffer size.
    #[rstest]
    #[case::empty(b"", &hex!("0000 0000 0000 0000"))]
    #[case::one(b")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two(b"it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three(b"tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four(b"were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five(b"where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six(b"unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven(b"where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned(b"read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes(b"read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_write_slice(
        #[case] value: &[u8],
        #[case] buf: &[u8],
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] buf_size: usize,
    ) {
        // The mock expects the encoded bytes in `chunks_size` pieces, with a
        // zero-length wait after each piece.
        let mut expected = Builder::new();
        for chunk in buf.chunks(chunks_size) {
            expected.write(chunk).wait(Duration::ZERO);
        }
        let mut writer = NixWriter::builder()
            .set_max_buf_size(buf_size)
            .build(expected.build());

        writer.write_slice(value).await.unwrap();
        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }

    // `Display` values are staged in the buffer with the same framing as
    // byte slices, then delivered to the mock on flush.
    #[rstest]
    #[case::empty("", &hex!("0000 0000 0000 0000"))]
    #[case::one(")", &hex!("0100 0000 0000 0000 2900 0000 0000 0000"))]
    #[case::two("it", &hex!("0200 0000 0000 0000 6974 0000 0000 0000"))]
    #[case::three("tea", &hex!("0300 0000 0000 0000 7465 6100 0000 0000"))]
    #[case::four("were", &hex!("0400 0000 0000 0000 7765 7265 0000 0000"))]
    #[case::five("where", &hex!("0500 0000 0000 0000 7768 6572 6500 0000"))]
    #[case::six("unwrap", &hex!("0600 0000 0000 0000 756E 7772 6170 0000"))]
    #[case::seven("where's", &hex!("0700 0000 0000 0000 7768 6572 6527 7300"))]
    #[case::aligned("read_tea", &hex!("0800 0000 0000 0000 7265 6164 5F74 6561"))]
    #[case::more_bytes("read_tess", &hex!("0900 0000 0000 0000 7265 6164 5F74 6573 7300 0000 0000 0000"))]
    #[tokio::test]
    async fn test_write_display(
        #[case] value: &str,
        #[case] buf: &[u8],
        #[values(1, 2, 3, 4, 5, 6, 7, 8, 9, 1024)] chunks_size: usize,
    ) {
        // The mock expects the encoded bytes in `chunks_size` pieces, with a
        // zero-length wait after each piece.
        let mut expected = Builder::new();
        for chunk in buf.chunks(chunks_size) {
            expected.write(chunk).wait(Duration::ZERO);
        }
        let mut writer = NixWriter::builder().build(expected.build());

        writer.write_display(value).await.unwrap();
        assert_eq!(writer.buffer(), buf);

        writer.flush().await.unwrap();
        assert_eq!(writer.buffer(), b"");
    }
}