1use std::{
2 convert::Infallible,
3 future::{poll_fn, IntoFuture},
4 io,
5 marker::PhantomData,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use axum07::{
13 body::Body,
14 extract::{connect_info::Connected, Request},
15 response::Response,
16};
17use futures_util::{pin_mut, FutureExt};
18use hyper1::body::Incoming;
19use hyper_util::{
20 rt::{TokioExecutor, TokioIo},
21 server::conn::auto::Builder,
22};
23use std::future::Future;
24use tokio::sync::watch;
25use tower::{util::Oneshot, ServiceExt};
26use tower_service::Service;
27use tracing::trace;
28
29use crate::{is_connection_error, SomeSocketAddr, SomeSocketAddrClonable};
30
31#[derive(Debug)]
37pub struct IncomingStream<'a> {
38 stream: &'a TokioIo<crate::Connection>,
39 remote_addr: SomeSocketAddrClonable,
40}
41
42impl IncomingStream<'_> {
43 #[allow(clippy::missing_errors_doc)]
45 pub fn local_addr(&self) -> std::io::Result<SomeSocketAddr> {
46 let q = self.stream.inner();
47 if let Some(a) = q.try_borrow_tcp() {
48 return Ok(SomeSocketAddr::Tcp(a.local_addr()?));
49 }
50 #[cfg(all(feature = "unix", unix))]
51 if let Some(a) = q.try_borrow_unix() {
52 return Ok(SomeSocketAddr::Unix(a.local_addr()?));
53 }
54 #[cfg(feature = "inetd")]
55 if q.try_borrow_stdio().is_some() {
56 return Ok(SomeSocketAddr::Stdio);
57 }
58 Err(std::io::Error::other(
59 "unhandled tokio-listener address type",
60 ))
61 }
62
63 #[must_use]
65 pub fn remote_addr(&self) -> SomeSocketAddrClonable {
66 self.remote_addr.clone()
67 }
68}
69
70impl Connected<IncomingStream<'_>> for SomeSocketAddrClonable {
71 fn connect_info(target: IncomingStream<'_>) -> Self {
72 target.remote_addr()
73 }
74}
75
76pub struct Serve<M, S> {
78 tokio_listener: crate::Listener,
79 make_service: M,
80 _marker: PhantomData<S>,
81}
82
83pub fn serve<M, S>(tokio_listener: crate::Listener, make_service: M) -> Serve<M, S>
92where
93 M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
94 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
95 S::Future: Send,
96{
97 Serve {
98 tokio_listener,
99 make_service,
100 _marker: PhantomData,
101 }
102}
103
104impl<M, S> IntoFuture for Serve<M, S>
105where
106 M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
107 for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
108 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
109 S::Future: Send,
110{
111 type Output = io::Result<()>;
112 type IntoFuture = private::ServeFuture;
113
114 fn into_future(self) -> Self::IntoFuture {
115 private::ServeFuture(Box::pin(async move {
116 let Self {
117 mut tokio_listener,
118 mut make_service,
119 _marker: _,
120 } = self;
121
122 loop {
123 let Some((stream, remote_addr)) = tokio_listener_accept(&mut tokio_listener).await
124 else {
125 if tokio_listener.no_more_connections() {
126 return Ok(());
127 }
128 continue;
129 };
130 let stream = TokioIo::new(stream);
131
132 poll_fn(|cx| make_service.poll_ready(cx))
133 .await
134 .unwrap_or_else(|err| match err {});
135
136 let tower_service = make_service
137 .call(IncomingStream {
138 stream: &stream,
139 remote_addr: remote_addr.clonable(),
140 })
141 .await
142 .unwrap_or_else(|err| match err {});
143
144 let hyper_service = TowerToHyperService {
145 service: tower_service,
146 };
147
148 tokio::spawn(async move {
149 match Builder::new(TokioExecutor::new())
150 .serve_connection_with_upgrades(stream, hyper_service)
152 .await
153 {
154 Ok(()) => {}
155 Err(_err) => {
156 }
162 }
163 });
164 }
165 }))
166 }
167}
168
169mod private {
170 use std::{
171 future::Future,
172 io,
173 pin::Pin,
174 task::{Context, Poll},
175 };
176
177 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
178
179 impl Future for ServeFuture {
180 type Output = io::Result<()>;
181
182 #[inline]
183 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
184 self.0.as_mut().poll(cx)
185 }
186 }
187
188 impl std::fmt::Debug for ServeFuture {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("ServeFuture").finish_non_exhaustive()
191 }
192 }
193}
194
195#[derive(Debug, Copy, Clone)]
196struct TowerToHyperService<S> {
197 service: S,
198}
199
200impl<S> hyper1::service::Service<Request<Incoming>> for TowerToHyperService<S>
201where
202 S: tower_service::Service<Request> + Clone,
203{
204 type Response = S::Response;
205 type Error = S::Error;
206 type Future = TowerToHyperServiceFuture<S, Request>;
207
208 fn call(&self, req: Request<Incoming>) -> Self::Future {
209 let req = req.map(Body::new);
210 TowerToHyperServiceFuture {
211 future: self.service.clone().oneshot(req),
212 }
213 }
214}
215
216#[pin_project::pin_project]
217struct TowerToHyperServiceFuture<S, R>
218where
219 S: tower_service::Service<R>,
220{
221 #[pin]
222 future: Oneshot<S, R>,
223}
224
225impl<S, R> Future for TowerToHyperServiceFuture<S, R>
226where
227 S: tower_service::Service<R>,
228{
229 type Output = Result<S::Response, S::Error>;
230
231 #[inline]
232 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 self.project().future.poll(cx)
234 }
235}
236
237async fn tokio_listener_accept(
238 listener: &mut crate::Listener,
239) -> Option<(crate::Connection, SomeSocketAddr)> {
240 match listener.accept().await {
241 Ok(conn) => Some(conn),
242 Err(e) => {
243 if is_connection_error(&e) || listener.no_more_connections() {
244 return None;
245 }
246
247 tracing::error!("accept error: {e}");
259 tokio::time::sleep(Duration::from_secs(1)).await;
260 None
261 }
262 }
263}
264
265pub struct WithGracefulShutdown<M, S, F> {
267 tokio_listener: crate::Listener,
268 make_service: M,
269 signal: F,
270 _marker: PhantomData<S>,
271}
272
273impl<M, S, F> std::fmt::Debug for WithGracefulShutdown<M, S, F>
274where
275 M: std::fmt::Debug,
276 S: std::fmt::Debug,
277 F: std::fmt::Debug,
278{
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 let Self {
281 tokio_listener,
282 make_service,
283 signal,
284 _marker: _,
285 } = self;
286
287 f.debug_struct("WithGracefulShutdown")
288 .field("tokio_listener", tokio_listener)
289 .field("make_service", make_service)
290 .field("signal", signal)
291 .finish()
292 }
293}
294
295impl<M, S> std::fmt::Debug for Serve<M, S>
296where
297 M: std::fmt::Debug,
298{
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 let Self {
301 tokio_listener,
302 make_service,
303 _marker: _,
304 } = self;
305
306 f.debug_struct("Serve")
307 .field("tokio_listener", tokio_listener)
308 .field("make_service", make_service)
309 .finish()
310 }
311}
312
313#[allow(clippy::single_match_else)]
314impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
315where
316 M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
317 for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
318 S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
319 S::Future: Send,
320 F: Future<Output = ()> + Send + 'static,
321{
322 type Output = io::Result<()>;
323 type IntoFuture = private::ServeFuture;
324
325 fn into_future(self) -> Self::IntoFuture {
326 let Self {
327 mut tokio_listener,
328 mut make_service,
329 signal,
330 _marker: _,
331 } = self;
332
333 let (signal_tx, signal_rx) = watch::channel(());
334 let signal_tx = Arc::new(signal_tx);
335 tokio::spawn(async move {
336 signal.await;
337 tracing::trace!("received graceful shutdown signal. Telling tasks to shutdown");
338 drop(signal_rx);
339 });
340
341 let (close_tx, close_rx) = watch::channel(());
342
343 private::ServeFuture(Box::pin(async move {
344 loop {
345 let (stream, remote_addr) = tokio::select! {
346 conn = tokio_listener_accept(&mut tokio_listener) => {
347 match conn {
348 Some(conn) => conn,
349 None => {
350 if tokio_listener.no_more_connections() {
351 break;
352 }
353 continue
354 }
355 }
356 }
357 () = signal_tx.closed() => {
358 trace!("signal received, not accepting new connections");
359 break;
360 }
361 };
362 let stream = TokioIo::new(stream);
363
364 trace!("connection {remote_addr} accepted");
365
366 poll_fn(|cx| make_service.poll_ready(cx))
367 .await
368 .unwrap_or_else(|err| match err {});
369
370 let tower_service = make_service
371 .call(IncomingStream {
372 stream: &stream,
373 remote_addr: remote_addr.clonable(),
374 })
375 .await
376 .unwrap_or_else(|err| match err {});
377
378 let hyper_service = TowerToHyperService {
379 service: tower_service,
380 };
381
382 let signal_tx = Arc::clone(&signal_tx);
383
384 let close_rx = close_rx.clone();
385
386 tokio::spawn(async move {
387 let builder = Builder::new(TokioExecutor::new());
388 let conn = builder.serve_connection_with_upgrades(stream, hyper_service);
389 pin_mut!(conn);
390
391 let signal_closed = signal_tx.closed().fuse();
392 pin_mut!(signal_closed);
393
394 loop {
395 tokio::select! {
396 result = conn.as_mut() => {
397 if let Err(err) = result {
398 trace!("failed to serve connection: {err:#}");
399 }
400 break;
401 }
402 () = &mut signal_closed => {
403 trace!("signal received in task, starting graceful shutdown");
404 conn.as_mut().graceful_shutdown();
405 }
406 }
407 }
408
409 trace!("a connection closed");
410
411 drop(close_rx);
412 });
413 }
414
415 drop(close_rx);
416 drop(tokio_listener);
417
418 trace!(
419 "waiting for {} task(s) to finish",
420 close_tx.receiver_count()
421 );
422 close_tx.closed().await;
423
424 Ok(())
425 }))
426 }
427}
428
429impl<M, S> Serve<M, S> {
430 pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
436 where
437 F: Future<Output = ()> + Send + 'static,
438 {
439 WithGracefulShutdown {
440 tokio_listener: self.tokio_listener,
441 make_service: self.make_service,
442 signal,
443 _marker: PhantomData,
444 }
445 }
446}