1#[allow(unused_imports)]
2use std::{
3 ffi::c_int,
4 fmt::Display,
5 net::SocketAddr,
6 path::PathBuf,
7 pin::Pin,
8 str::FromStr,
9 sync::Arc,
10 task::{ready, Context, Poll},
11 time::Duration,
12};
13
14use pin_project::pin_project;
15use tokio::{
16 io::{AsyncRead, AsyncWrite, Stdin, Stdout},
17 net::TcpStream,
18 sync::oneshot::Sender,
19};
20use tracing::{debug, warn};
21
22#[cfg(unix)]
23use tokio::net::UnixStream;
24
25#[pin_project]
29pub struct Connection(#[pin] pub(crate) ConnectionImpl);
30
31impl std::fmt::Debug for Connection {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self.0 {
34 ConnectionImpl::Tcp(_) => f.write_str("Connection(tcp)"),
35 #[cfg(all(feature = "unix", unix))]
36 ConnectionImpl::Unix(_) => f.write_str("Connection(unix)"),
37 #[cfg(feature = "inetd")]
38 ConnectionImpl::Stdio(_, _, _) => f.write_str("Connection(stdio)"),
39 }
40 }
41}
42
43#[derive(Debug)]
44#[pin_project(project = ConnectionImplProj)]
45pub(crate) enum ConnectionImpl {
46 Tcp(#[pin] TcpStream),
47 #[cfg(all(feature = "unix", unix))]
48 Unix(#[pin] UnixStream),
49 #[cfg(feature = "inetd")]
50 Stdio(
51 #[pin] tokio::io::Stdin,
52 #[pin] tokio::io::Stdout,
53 Option<Sender<()>>,
54 ),
55}
56
57#[allow(missing_docs)]
58#[allow(clippy::missing_errors_doc)]
59impl Connection {
60 pub fn try_into_tcp(self) -> Result<TcpStream, Self> {
61 if let ConnectionImpl::Tcp(s) = self.0 {
62 Ok(s)
63 } else {
64 Err(self)
65 }
66 }
67 #[cfg(all(feature = "unix", unix))]
68 #[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
69 pub fn try_into_unix(self) -> Result<UnixStream, Self> {
70 if let ConnectionImpl::Unix(s) = self.0 {
71 Ok(s)
72 } else {
73 Err(self)
74 }
75 }
76 #[cfg(feature = "inetd")]
77 #[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
78 pub fn try_into_stdio(self) -> Result<(Stdin, Stdout, Option<Sender<()>>), Self> {
85 if let ConnectionImpl::Stdio(i, o, f) = self.0 {
86 Ok((i, o, f))
87 } else {
88 Err(self)
89 }
90 }
91
92 pub fn try_borrow_tcp(&self) -> Option<&TcpStream> {
93 if let ConnectionImpl::Tcp(ref s) = self.0 {
94 Some(s)
95 } else {
96 None
97 }
98 }
99 #[cfg(all(feature = "unix", unix))]
100 #[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
101 pub fn try_borrow_unix(&self) -> Option<&UnixStream> {
102 if let ConnectionImpl::Unix(ref s) = self.0 {
103 Some(s)
104 } else {
105 None
106 }
107 }
108 #[cfg(feature = "inetd")]
109 #[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
110 pub fn try_borrow_stdio(&self) -> Option<(&Stdin, &Stdout)> {
111 if let ConnectionImpl::Stdio(ref i, ref o, ..) = self.0 {
112 Some((i, o))
113 } else {
114 None
115 }
116 }
117}
118
119impl From<TcpStream> for Connection {
120 fn from(s: TcpStream) -> Self {
121 Connection(ConnectionImpl::Tcp(s))
122 }
123}
#[cfg(all(feature = "unix", unix))]
#[cfg_attr(docsrs_alt, doc(cfg(all(feature = "unix", unix))))]
impl From<UnixStream> for Connection {
    /// Wraps a `UnixStream` as a Unix-backed connection.
    fn from(stream: UnixStream) -> Self {
        Self(ConnectionImpl::Unix(stream))
    }
}
#[cfg(feature = "inetd")]
#[cfg_attr(docsrs_alt, doc(cfg(feature = "inetd")))]
impl From<(Stdin, Stdout, Option<Sender<()>>)> for Connection {
    /// Builds a stdio-backed connection from its three parts, in tuple order:
    /// stdin, stdout and the optional one-shot sender.
    fn from((stdin, stdout, finish_tx): (Stdin, Stdout, Option<Sender<()>>)) -> Self {
        Self(ConnectionImpl::Stdio(stdin, stdout, finish_tx))
    }
}
138
139impl AsyncRead for Connection {
140 #[inline]
141 fn poll_read(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 buf: &mut tokio::io::ReadBuf<'_>,
145 ) -> Poll<std::io::Result<()>> {
146 let q: Pin<&mut ConnectionImpl> = self.project().0;
147 match q.project() {
148 ConnectionImplProj::Tcp(s) => s.poll_read(cx, buf),
149 #[cfg(all(feature = "unix", unix))]
150 ConnectionImplProj::Unix(s) => s.poll_read(cx, buf),
151 #[cfg(feature = "inetd")]
152 ConnectionImplProj::Stdio(s, _, _) => s.poll_read(cx, buf),
153 }
154 }
155}
156
157impl AsyncWrite for Connection {
158 #[inline]
159 fn poll_write(
160 self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &[u8],
163 ) -> Poll<Result<usize, std::io::Error>> {
164 let q: Pin<&mut ConnectionImpl> = self.project().0;
165 match q.project() {
166 ConnectionImplProj::Tcp(s) => s.poll_write(cx, buf),
167 #[cfg(all(feature = "unix", unix))]
168 ConnectionImplProj::Unix(s) => s.poll_write(cx, buf),
169 #[cfg(feature = "inetd")]
170 ConnectionImplProj::Stdio(_, s, _) => s.poll_write(cx, buf),
171 }
172 }
173
174 #[inline]
175 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
176 let q: Pin<&mut ConnectionImpl> = self.project().0;
177 match q.project() {
178 ConnectionImplProj::Tcp(s) => s.poll_flush(cx),
179 #[cfg(all(feature = "unix", unix))]
180 ConnectionImplProj::Unix(s) => s.poll_flush(cx),
181 #[cfg(feature = "inetd")]
182 ConnectionImplProj::Stdio(_, s, _) => s.poll_flush(cx),
183 }
184 }
185
186 #[inline]
187 fn poll_shutdown(
188 self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 ) -> Poll<Result<(), std::io::Error>> {
191 let q: Pin<&mut ConnectionImpl> = self.project().0;
192 match q.project() {
193 ConnectionImplProj::Tcp(s) => s.poll_shutdown(cx),
194 #[cfg(all(feature = "unix", unix))]
195 ConnectionImplProj::Unix(s) => s.poll_shutdown(cx),
196 #[cfg(feature = "inetd")]
197 ConnectionImplProj::Stdio(_, s, tx) => match s.poll_shutdown(cx) {
198 Poll::Pending => Poll::Pending,
199 Poll::Ready(ret) => {
200 if let Some(tx) = tx.take() {
201 if tx.send(()).is_err() {
202 warn!("stdout wrapper for inetd mode failed to notify the listener to abort listening loop");
203 } else {
204 debug!("stdout finished in inetd mode. Aborting the listening loop.");
205 }
206 }
207 Poll::Ready(ret)
208 }
209 },
210 }
211 }
212
213 #[inline]
214 fn poll_write_vectored(
215 self: Pin<&mut Self>,
216 cx: &mut Context<'_>,
217 bufs: &[std::io::IoSlice<'_>],
218 ) -> Poll<Result<usize, std::io::Error>> {
219 let q: Pin<&mut ConnectionImpl> = self.project().0;
220 match q.project() {
221 ConnectionImplProj::Tcp(s) => s.poll_write_vectored(cx, bufs),
222 #[cfg(all(feature = "unix", unix))]
223 ConnectionImplProj::Unix(s) => s.poll_write_vectored(cx, bufs),
224 #[cfg(feature = "inetd")]
225 ConnectionImplProj::Stdio(_, s, _) => s.poll_write_vectored(cx, bufs),
226 }
227 }
228
229 #[inline]
230 fn is_write_vectored(&self) -> bool {
231 match &self.0 {
232 ConnectionImpl::Tcp(s) => s.is_write_vectored(),
233 #[cfg(all(feature = "unix", unix))]
234 ConnectionImpl::Unix(s) => s.is_write_vectored(),
235 #[cfg(feature = "inetd")]
236 ConnectionImpl::Stdio(_, s, _) => s.is_write_vectored(),
237 }
238 }
239}