tokio_listener/
connection.rs

1#[allow(unused_imports)]
2use std::{
3    ffi::c_int,
4    fmt::Display,
5    net::SocketAddr,
6    path::PathBuf,
7    pin::Pin,
8    str::FromStr,
9    sync::Arc,
10    task::{ready, Context, Poll},
11    time::Duration,
12};
13
14use pin_project::pin_project;
15use tokio::{
16    io::{AsyncRead, AsyncWrite, Stdin, Stdout},
17    net::TcpStream,
18    sync::oneshot::Sender,
19};
20use tracing::{debug, warn};
21
22#[cfg(unix)]
23use tokio::net::UnixStream;
24
25/// Accepted connection, which can be a TCP socket, AF_UNIX stream socket or a stdin/stdout pair.
26///
27/// Although inner enum is private, you can use methods or `From` impls to convert this to/from usual Tokio types.
28#[pin_project]
29pub struct Connection(#[pin] pub(crate) ConnectionImpl);
30
31impl std::fmt::Debug for Connection {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self.0 {
34            ConnectionImpl::Tcp(_) => f.write_str("Connection(tcp)"),
35            #[cfg(all(feature = "unix", unix))]
36            ConnectionImpl::Unix(_) => f.write_str("Connection(unix)"),
37            #[cfg(feature = "inetd")]
38            ConnectionImpl::Stdio(_, _, _) => f.write_str("Connection(stdio)"),
39        }
40    }
41}
42
43#[derive(Debug)]
44#[pin_project(project = ConnectionImplProj)]
45pub(crate) enum ConnectionImpl {
46    Tcp(#[pin] TcpStream),
47    #[cfg(all(feature = "unix", unix))]
48    Unix(#[pin] UnixStream),
49    #[cfg(feature = "inetd")]
50    Stdio(
51        #[pin] tokio::io::Stdin,
52        #[pin] tokio::io::Stdout,
53        Option<Sender<()>>,
54    ),
55}
56
57#[allow(missing_docs)]
58#[allow(clippy::missing_errors_doc)]
59impl Connection {
60    pub fn try_into_tcp(self) -> Result<TcpStream, Self> {
61        if let ConnectionImpl::Tcp(s) = self.0 {
62            Ok(s)
63        } else {
64            Err(self)
65        }
66    }
67    #[cfg(all(feature = "unix", unix))]
68    #[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
69    pub fn try_into_unix(self) -> Result<UnixStream, Self> {
70        if let ConnectionImpl::Unix(s) = self.0 {
71            Ok(s)
72        } else {
73            Err(self)
74        }
75    }
76    #[cfg(feature = "inetd")]
77    #[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
78    /// Get parts of the connection in case of inted mode is used.
79    ///
80    /// Third tuple part (Sender) should be used to signal [`Listener`] to exit from listening loop,
81    /// allowing proper timing of listening termination - without trying to wait for second client in inetd mode,
82    /// but also without exiting prematurely, while the client is still being served, as exiting the listening loop may
83    /// cause the whole process to finish.
84    pub fn try_into_stdio(self) -> Result<(Stdin, Stdout, Option<Sender<()>>), Self> {
85        if let ConnectionImpl::Stdio(i, o, f) = self.0 {
86            Ok((i, o, f))
87        } else {
88            Err(self)
89        }
90    }
91
92    pub fn try_borrow_tcp(&self) -> Option<&TcpStream> {
93        if let ConnectionImpl::Tcp(ref s) = self.0 {
94            Some(s)
95        } else {
96            None
97        }
98    }
99    #[cfg(all(feature = "unix", unix))]
100    #[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
101    pub fn try_borrow_unix(&self) -> Option<&UnixStream> {
102        if let ConnectionImpl::Unix(ref s) = self.0 {
103            Some(s)
104        } else {
105            None
106        }
107    }
108    #[cfg(feature = "inetd")]
109    #[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
110    pub fn try_borrow_stdio(&self) -> Option<(&Stdin, &Stdout)> {
111        if let ConnectionImpl::Stdio(ref i, ref o, ..) = self.0 {
112            Some((i, o))
113        } else {
114            None
115        }
116    }
117}
118
119impl From<TcpStream> for Connection {
120    fn from(s: TcpStream) -> Self {
121        Connection(ConnectionImpl::Tcp(s))
122    }
123}
/// Wraps a Tokio [`UnixStream`] as a [`Connection`].
#[cfg(all(feature = "unix", unix))]
#[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
impl From<UnixStream> for Connection {
    fn from(stream: UnixStream) -> Self {
        Self(ConnectionImpl::Unix(stream))
    }
}
/// Wraps a stdin/stdout pair, plus an optional stop-signal sender, as a [`Connection`].
#[cfg(feature = "inetd")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
impl From<(Stdin, Stdout, Option<Sender<()>>)> for Connection {
    fn from((stdin, stdout, stop_tx): (Stdin, Stdout, Option<Sender<()>>)) -> Self {
        Self(ConnectionImpl::Stdio(stdin, stdout, stop_tx))
    }
}
138
139impl AsyncRead for Connection {
140    #[inline]
141    fn poll_read(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144        buf: &mut tokio::io::ReadBuf<'_>,
145    ) -> Poll<std::io::Result<()>> {
146        let q: Pin<&mut ConnectionImpl> = self.project().0;
147        match q.project() {
148            ConnectionImplProj::Tcp(s) => s.poll_read(cx, buf),
149            #[cfg(all(feature = "unix", unix))]
150            ConnectionImplProj::Unix(s) => s.poll_read(cx, buf),
151            #[cfg(feature = "inetd")]
152            ConnectionImplProj::Stdio(s, _, _) => s.poll_read(cx, buf),
153        }
154    }
155}
156
157impl AsyncWrite for Connection {
158    #[inline]
159    fn poll_write(
160        self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        buf: &[u8],
163    ) -> Poll<Result<usize, std::io::Error>> {
164        let q: Pin<&mut ConnectionImpl> = self.project().0;
165        match q.project() {
166            ConnectionImplProj::Tcp(s) => s.poll_write(cx, buf),
167            #[cfg(all(feature = "unix", unix))]
168            ConnectionImplProj::Unix(s) => s.poll_write(cx, buf),
169            #[cfg(feature = "inetd")]
170            ConnectionImplProj::Stdio(_, s, _) => s.poll_write(cx, buf),
171        }
172    }
173
174    #[inline]
175    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
176        let q: Pin<&mut ConnectionImpl> = self.project().0;
177        match q.project() {
178            ConnectionImplProj::Tcp(s) => s.poll_flush(cx),
179            #[cfg(all(feature = "unix", unix))]
180            ConnectionImplProj::Unix(s) => s.poll_flush(cx),
181            #[cfg(feature = "inetd")]
182            ConnectionImplProj::Stdio(_, s, _) => s.poll_flush(cx),
183        }
184    }
185
186    #[inline]
187    fn poll_shutdown(
188        self: Pin<&mut Self>,
189        cx: &mut Context<'_>,
190    ) -> Poll<Result<(), std::io::Error>> {
191        let q: Pin<&mut ConnectionImpl> = self.project().0;
192        match q.project() {
193            ConnectionImplProj::Tcp(s) => s.poll_shutdown(cx),
194            #[cfg(all(feature = "unix", unix))]
195            ConnectionImplProj::Unix(s) => s.poll_shutdown(cx),
196            #[cfg(feature = "inetd")]
197            ConnectionImplProj::Stdio(_, s, tx) => match s.poll_shutdown(cx) {
198                Poll::Pending => Poll::Pending,
199                Poll::Ready(ret) => {
200                    if let Some(tx) = tx.take() {
201                        if tx.send(()).is_err() {
202                            warn!("stdout wrapper for inetd mode failed to notify the listener to abort listening loop");
203                        } else {
204                            debug!("stdout finished in inetd mode. Aborting the listening loop.");
205                        }
206                    }
207                    Poll::Ready(ret)
208                }
209            },
210        }
211    }
212
213    #[inline]
214    fn poll_write_vectored(
215        self: Pin<&mut Self>,
216        cx: &mut Context<'_>,
217        bufs: &[std::io::IoSlice<'_>],
218    ) -> Poll<Result<usize, std::io::Error>> {
219        let q: Pin<&mut ConnectionImpl> = self.project().0;
220        match q.project() {
221            ConnectionImplProj::Tcp(s) => s.poll_write_vectored(cx, bufs),
222            #[cfg(all(feature = "unix", unix))]
223            ConnectionImplProj::Unix(s) => s.poll_write_vectored(cx, bufs),
224            #[cfg(feature = "inetd")]
225            ConnectionImplProj::Stdio(_, s, _) => s.poll_write_vectored(cx, bufs),
226        }
227    }
228
229    #[inline]
230    fn is_write_vectored(&self) -> bool {
231        match &self.0 {
232            ConnectionImpl::Tcp(s) => s.is_write_vectored(),
233            #[cfg(all(feature = "unix", unix))]
234            ConnectionImpl::Unix(s) => s.is_write_vectored(),
235            #[cfg(feature = "inetd")]
236            ConnectionImpl::Stdio(_, s, _) => s.is_write_vectored(),
237        }
238    }
239}