tokio_listener/
axum07.rs

1use std::{
2    convert::Infallible,
3    future::{poll_fn, IntoFuture},
4    io,
5    marker::PhantomData,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use axum07::{
13    body::Body,
14    extract::{connect_info::Connected, Request},
15    response::Response,
16};
17use futures_util::{pin_mut, FutureExt};
18use hyper1::body::Incoming;
19use hyper_util::{
20    rt::{TokioExecutor, TokioIo},
21    server::conn::auto::Builder,
22};
23use std::future::Future;
24use tokio::sync::watch;
25use tower::{util::Oneshot, ServiceExt};
26use tower_service::Service;
27use tracing::trace;
28
29use crate::{is_connection_error, SomeSocketAddr, SomeSocketAddrClonable};
30
31/// An incoming stream.
32///
33/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
34///
35/// [`IntoMakeServiceWithConnectInfo`]: axum07::extract::connect_info::IntoMakeServiceWithConnectInfo
36#[derive(Debug)]
37pub struct IncomingStream<'a> {
38    stream: &'a TokioIo<crate::Connection>,
39    remote_addr: SomeSocketAddrClonable,
40}
41
42impl IncomingStream<'_> {
43    /// Returns the local address that this stream is bound to.
44    #[allow(clippy::missing_errors_doc)]
45    pub fn local_addr(&self) -> std::io::Result<SomeSocketAddr> {
46        let q = self.stream.inner();
47        if let Some(a) = q.try_borrow_tcp() {
48            return Ok(SomeSocketAddr::Tcp(a.local_addr()?));
49        }
50        #[cfg(all(feature = "unix", unix))]
51        if let Some(a) = q.try_borrow_unix() {
52            return Ok(SomeSocketAddr::Unix(a.local_addr()?));
53        }
54        #[cfg(feature = "inetd")]
55        if q.try_borrow_stdio().is_some() {
56            return Ok(SomeSocketAddr::Stdio);
57        }
58        Err(std::io::Error::other(
59            "unhandled tokio-listener address type",
60        ))
61    }
62
63    /// Returns the remote address that this stream is bound to.
64    #[must_use]
65    pub fn remote_addr(&self) -> SomeSocketAddrClonable {
66        self.remote_addr.clone()
67    }
68}
69
70impl Connected<IncomingStream<'_>> for SomeSocketAddrClonable {
71    fn connect_info(target: IncomingStream<'_>) -> Self {
72        target.remote_addr()
73    }
74}
75
76/// Future returned by [`serve`].
77pub struct Serve<M, S> {
78    tokio_listener: crate::Listener,
79    make_service: M,
80    _marker: PhantomData<S>,
81}
82
83/// Serve the service with the supplied `tokio_listener`-based listener.
84///
85/// See [`axum07::serve::serve`] for more documentation.
86///
87/// See the following examples in `tokio_listener` project:
88///
89/// * [`clap_axum07.rs`](https://github.com/vi/tokio-listener/blob/main/examples/clap_axum07.rs) for simple example
90/// * [`clap_axum07_advanced.rs`](https://github.com/vi/tokio-listener/blob/main/examples/clap_axum07_advanced.rs) for using incoming connection info and graceful shutdown.
91pub fn serve<M, S>(tokio_listener: crate::Listener, make_service: M) -> Serve<M, S>
92where
93    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
94    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
95    S::Future: Send,
96{
97    Serve {
98        tokio_listener,
99        make_service,
100        _marker: PhantomData,
101    }
102}
103
104impl<M, S> IntoFuture for Serve<M, S>
105where
106    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
107    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
108    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
109    S::Future: Send,
110{
111    type Output = io::Result<()>;
112    type IntoFuture = private::ServeFuture;
113
114    fn into_future(self) -> Self::IntoFuture {
115        private::ServeFuture(Box::pin(async move {
116            let Self {
117                mut tokio_listener,
118                mut make_service,
119                _marker: _,
120            } = self;
121
122            loop {
123                let Some((stream, remote_addr)) = tokio_listener_accept(&mut tokio_listener).await
124                else {
125                    if tokio_listener.no_more_connections() {
126                        return Ok(());
127                    }
128                    continue;
129                };
130                let stream = TokioIo::new(stream);
131
132                poll_fn(|cx| make_service.poll_ready(cx))
133                    .await
134                    .unwrap_or_else(|err| match err {});
135
136                let tower_service = make_service
137                    .call(IncomingStream {
138                        stream: &stream,
139                        remote_addr: remote_addr.clonable(),
140                    })
141                    .await
142                    .unwrap_or_else(|err| match err {});
143
144                let hyper_service = TowerToHyperService {
145                    service: tower_service,
146                };
147
148                tokio::spawn(async move {
149                    match Builder::new(TokioExecutor::new())
150                        // upgrades needed for websockets
151                        .serve_connection_with_upgrades(stream, hyper_service)
152                        .await
153                    {
154                        Ok(()) => {}
155                        Err(_err) => {
156                            // This error only appears when the client doesn't send a request and
157                            // terminate the connection.
158                            //
159                            // If client sends one request then terminate connection whenever, it doesn't
160                            // appear.
161                        }
162                    }
163                });
164            }
165        }))
166    }
167}
168
mod private {
    use std::{
        fmt,
        future::Future,
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    /// Type-erased future produced by the `IntoFuture` impls of `Serve` and
    /// `WithGracefulShutdown`; it resolves to the server loop's final result.
    ///
    /// It lives in a private module, presumably so that it stays unnameable outside this crate
    /// (it only surfaces as the `IntoFuture::IntoFuture` associated type).
    pub struct ServeFuture(
        pub(super) Pin<Box<dyn Future<Output = io::Result<()>> + Send + 'static>>,
    );

    impl Future for ServeFuture {
        type Output = io::Result<()>;

        #[inline]
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // The only field is an already-pinned `Box`, so `ServeFuture` is `Unpin` and we can
            // reach through the outer pin directly.
            let boxed = &mut self.get_mut().0;
            boxed.as_mut().poll(cx)
        }
    }

    impl fmt::Debug for ServeFuture {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // The boxed future is opaque, so only the type name is shown.
            f.debug_struct("ServeFuture").finish_non_exhaustive()
        }
    }
}
194
/// Adapter that lets a cloneable tower service be driven by hyper 1.x.
///
/// hyper's `Service::call` takes `&self`, so each request is dispatched to a fresh clone of the
/// wrapped tower service.
#[derive(Clone, Copy, Debug)]
struct TowerToHyperService<S> {
    service: S,
}
199
200impl<S> hyper1::service::Service<Request<Incoming>> for TowerToHyperService<S>
201where
202    S: tower_service::Service<Request> + Clone,
203{
204    type Response = S::Response;
205    type Error = S::Error;
206    type Future = TowerToHyperServiceFuture<S, Request>;
207
208    fn call(&self, req: Request<Incoming>) -> Self::Future {
209        let req = req.map(Body::new);
210        TowerToHyperServiceFuture {
211            future: self.service.clone().oneshot(req),
212        }
213    }
214}
215
216#[pin_project::pin_project]
217struct TowerToHyperServiceFuture<S, R>
218where
219    S: tower_service::Service<R>,
220{
221    #[pin]
222    future: Oneshot<S, R>,
223}
224
225impl<S, R> Future for TowerToHyperServiceFuture<S, R>
226where
227    S: tower_service::Service<R>,
228{
229    type Output = Result<S::Response, S::Error>;
230
231    #[inline]
232    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233        self.project().future.poll(cx)
234    }
235}
236
237async fn tokio_listener_accept(
238    listener: &mut crate::Listener,
239) -> Option<(crate::Connection, SomeSocketAddr)> {
240    match listener.accept().await {
241        Ok(conn) => Some(conn),
242        Err(e) => {
243            if is_connection_error(&e) || listener.no_more_connections() {
244                return None;
245            }
246
247            // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
248            //
249            // > A possible scenario is that the process has hit the max open files
250            // > allowed, and so trying to accept a new connection will fail with
251            // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
252            // > the application will likely close some files (or connections), and try
253            // > to accept the connection again. If this option is `true`, the error
254            // > will be logged at the `error` level, since it is still a big deal,
255            // > and then the listener will sleep for 1 second.
256            //
257            // hyper allowed customizing this but axum does not.
258            tracing::error!("accept error: {e}");
259            tokio::time::sleep(Duration::from_secs(1)).await;
260            None
261        }
262    }
263}
264
265/// Serve future with graceful shutdown enabled.
266pub struct WithGracefulShutdown<M, S, F> {
267    tokio_listener: crate::Listener,
268    make_service: M,
269    signal: F,
270    _marker: PhantomData<S>,
271}
272
273impl<M, S, F> std::fmt::Debug for WithGracefulShutdown<M, S, F>
274where
275    M: std::fmt::Debug,
276    S: std::fmt::Debug,
277    F: std::fmt::Debug,
278{
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        let Self {
281            tokio_listener,
282            make_service,
283            signal,
284            _marker: _,
285        } = self;
286
287        f.debug_struct("WithGracefulShutdown")
288            .field("tokio_listener", tokio_listener)
289            .field("make_service", make_service)
290            .field("signal", signal)
291            .finish()
292    }
293}
294
295impl<M, S> std::fmt::Debug for Serve<M, S>
296where
297    M: std::fmt::Debug,
298{
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        let Self {
301            tokio_listener,
302            make_service,
303            _marker: _,
304        } = self;
305
306        f.debug_struct("Serve")
307            .field("tokio_listener", tokio_listener)
308            .field("make_service", make_service)
309            .finish()
310    }
311}
312
// Presumably silences `clippy::single_match_else` for the `match conn { Some(..) => .., None => { .. } }`
// inside the accept `select!` below. NOTE(review): confirm, and drop it if no longer triggered.
#[allow(clippy::single_match_else)]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the serving future with graceful shutdown.
    ///
    /// Shutdown coordination uses two `watch` channels purely as "all receivers gone"
    /// notifications (the `()` values are never read):
    ///
    /// * `signal_tx`/`signal_rx`: a task spawned right here awaits `signal` and then drops
    ///   `signal_rx`. That is what `signal_tx.closed()` waits for, both in the accept loop and
    ///   in every connection task.
    /// * `close_tx`/`close_rx`: each connection task holds a clone of `close_rx`. After the
    ///   accept loop ends, `close_tx.closed()` waits until all of them have been dropped.
    ///
    /// NOTE(review): `tokio::spawn` is called eagerly here, before the returned future is
    /// polled, so this presumably has to run inside a Tokio runtime.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            mut tokio_listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        // Shutdown broadcast: completion of `signal` is announced by dropping `signal_rx`.
        // `signal_tx` is wrapped in `Arc` so each connection task can hold its own handle.
        let (signal_tx, signal_rx) = watch::channel(());
        let signal_tx = Arc::new(signal_tx);
        tokio::spawn(async move {
            signal.await;
            tracing::trace!("received graceful shutdown signal. Telling tasks to shutdown");
            drop(signal_rx);
        });

        // Connection tracking: every connection task owns a `close_rx` clone; the server
        // future finishes once all of them (and this original) have been dropped.
        let (close_tx, close_rx) = watch::channel(());

        private::ServeFuture(Box::pin(async move {
            loop {
                // Race the next accept against the shutdown signal.
                let (stream, remote_addr) = tokio::select! {
                    conn = tokio_listener_accept(&mut tokio_listener) => {
                        match conn {
                            Some(conn) => conn,
                            None => {
                                // Accept failed; stop only if the listener is exhausted.
                                if tokio_listener.no_more_connections() {
                                    break;
                                }
                                continue
                            }
                        }
                    }
                    () = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };
                let stream = TokioIo::new(stream);

                trace!("connection {remote_addr} accepted");

                // Both calls have `Error = Infallible`, so these closures can never run.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        stream: &stream,
                        remote_addr: remote_addr.clonable(),
                    })
                    .await
                    .unwrap_or_else(|err| match err {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                // Per-connection handles: one to observe the shutdown signal, one to keep the
                // server future waiting until this connection has finished.
                let signal_tx = Arc::clone(&signal_tx);

                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    let conn = builder.serve_connection_with_upgrades(stream, hyper_service);
                    pin_mut!(conn);

                    // Fused, presumably so it can be polled again on later loop iterations
                    // after it has already fired.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    // Drive the connection to completion. If shutdown is signalled first,
                    // request a graceful shutdown and keep looping until `conn` finishes.
                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(err) = result {
                                    trace!("failed to serve connection: {err:#}");
                                }
                                break;
                            }
                            () = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    trace!("a connection closed");

                    // Dropping this task's `close_rx` tells the server future it is done.
                    drop(close_rx);
                });
            }

            // Release the loop's own receiver so that only connection tasks keep `close_tx`
            // open, and drop the listener before waiting for in-flight connections.
            drop(close_rx);
            drop(tokio_listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
428
429impl<M, S> Serve<M, S> {
430    /// Prepares a server to handle graceful shutdown when the provided future completes.
431    ///
432    /// See [the original documentation][1] for the example.
433    ///
434    /// [1]: axum07::serve::Serve::with_graceful_shutdown
435    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
436    where
437        F: Future<Output = ()> + Send + 'static,
438    {
439        WithGracefulShutdown {
440            tokio_listener: self.tokio_listener,
441            make_service: self.make_service,
442            signal,
443            _marker: PhantomData,
444        }
445    }
446}