axum/
serve.rs

1//! Serve services.
2
3use std::{
4    convert::Infallible,
5    fmt::Debug,
6    future::{poll_fn, Future, IntoFuture},
7    io,
8    marker::PhantomData,
9    net::SocketAddr,
10    sync::Arc,
11    time::Duration,
12};
13
14use axum_core::{body::Body, extract::Request, response::Response};
15use futures_util::{pin_mut, FutureExt};
16use hyper::body::Incoming;
17use hyper_util::rt::{TokioExecutor, TokioIo};
18#[cfg(any(feature = "http1", feature = "http2"))]
19use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
20use tokio::{
21    net::{TcpListener, TcpStream},
22    sync::watch,
23};
24use tower::ServiceExt as _;
25use tower_service::Service;
26
/// Serves `make_service` on every connection accepted from `tcp_listener`.
///
/// This is deliberately the simplest way to run a service and offers no configuration beyond
/// what [`Serve`] exposes. Use hyper or hyper-util directly when you need more control.
///
/// Both HTTP/1 and HTTP/2 are supported.
///
/// # Examples
///
/// Serving a [`Router`]:
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`Router::into_make_service_with_connect_info`].
///
/// Serving a [`MethodRouter`]:
///
/// ```
/// use axum::routing::get;
///
/// # async {
/// let router = get(|| async { "Hello, World!" });
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`MethodRouter::into_make_service_with_connect_info`].
///
/// Serving a [`Handler`]:
///
/// ```
/// use axum::handler::HandlerWithoutStateExt;
///
/// # async {
/// async fn handler() -> &'static str {
///     "Hello, World!"
/// }
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and
/// [`HandlerService::into_make_service_with_connect_info`].
///
/// [`Router`]: crate::Router
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
/// [`MethodRouter`]: crate::routing::MethodRouter
/// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // TCP_NODELAY is left at the OS default until the caller opts in via `Serve::tcp_nodelay`.
    let tcp_nodelay = None;

    Serve {
        make_service,
        tcp_listener,
        tcp_nodelay,
        _marker: PhantomData,
    }
}
105
/// The future returned by [`serve`].
///
/// Await it to run the server, or configure it first with [`Serve::tcp_nodelay`] or
/// [`Serve::with_graceful_shutdown`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
    // Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    // Called once per accepted connection to produce the service that handles it.
    make_service: M,
    // `Some(v)` applies `set_nodelay(v)` to each accepted stream; `None` leaves it untouched.
    tcp_nodelay: Option<bool>,
    // Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
115
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
    /// Converts this server into one that shuts down gracefully once `signal` completes.
    ///
    /// When `signal` resolves, the server stops accepting new connections, asks open
    /// connections to shut down gracefully, and finishes once they have all closed.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            tcp_listener,
            make_service,
            tcp_nodelay,
            _marker,
        } = self;

        // Carry every existing setting over unchanged; only `signal` is new.
        WithGracefulShutdown {
            tcp_listener,
            make_service,
            signal,
            tcp_nodelay,
            _marker,
        }
    }

    /// Sets the `TCP_NODELAY` option to `nodelay` on every connection the server accepts.
    ///
    /// See also [`TcpStream::set_nodelay`].
    ///
    /// # Example
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .tcp_nodelay(true)
    ///     .await
    ///     .unwrap();
    /// # };
    /// ```
    pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
        self.tcp_nodelay = Some(nodelay);
        self
    }

    /// Returns the local address of the listener this server accepts connections on.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let listener = &self.tcp_listener;
        listener.local_addr()
    }
}
182
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
    M: Debug,
{
    /// Prints every field except the `PhantomData` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The destructuring is exhaustive on purpose. Adding a field to `Serve` without
        // updating this impl then fails to compile, instead of silently dropping the field
        // from the output.
        let Serve {
            tcp_listener: listener,
            make_service: make_svc,
            tcp_nodelay: nodelay,
            _marker: _,
        } = self;

        let mut out = f.debug_struct("Serve");
        out.field("tcp_listener", listener);
        out.field("make_service", make_svc);
        out.field("tcp_nodelay", nodelay);
        out.finish()
    }
}
203
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Runs the accept loop forever.
    ///
    /// For each accepted connection, the loop:
    ///
    /// - optionally applies `TCP_NODELAY`;
    /// - asks `make_service` for a per-connection service;
    /// - drives that connection on its own spawned task, with HTTP upgrades enabled.
    ///
    /// The loop has no exit, so the returned future never completes on its own. Accept
    /// failures are absorbed by `tcp_accept`, and the loop simply tries again.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                tcp_nodelay,
                _marker: _,
            } = self;

            loop {
                let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
                    Some(conn) => conn,
                    None => continue,
                };

                if let Some(nodelay) = tcp_nodelay {
                    if let Err(err) = tcp_stream.set_nodelay(nodelay) {
                        trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                    }
                }

                let tcp_stream = TokioIo::new(tcp_stream);

                // `Error = Infallible`, so the empty match proves these can never fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                tokio::spawn(async move {
                    match Builder::new(TokioExecutor::new())
                        // upgrades needed for websockets
                        .serve_connection_with_upgrades(tcp_stream, hyper_service)
                        .await
                    {
                        Ok(()) => {}
                        Err(_err) => {
                            // A failed connection only affects that one client, so the error
                            // is traced, not propagated. This matches the graceful-shutdown
                            // variant. The leading underscore keeps the binding warning-free
                            // when tracing is compiled out.
                            trace!("failed to serve connection: {_err:#}");
                        }
                    }
                });
            }
        }))
    }
}
273
/// A [`Serve`] future that shuts down gracefully once its signal future completes.
///
/// Created by [`Serve::with_graceful_shutdown`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
    // Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    // Called once per accepted connection to produce the service that handles it.
    make_service: M,
    // Future whose completion triggers graceful shutdown.
    signal: F,
    // `Some(v)` applies `set_nodelay(v)` to each accepted stream; `None` leaves it untouched.
    tcp_nodelay: Option<bool>,
    // Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
284
// Gated on the same features as the `WithGracefulShutdown` struct itself. Without this, the
// impl would refer to an undefined type when those features are off.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> WithGracefulShutdown<M, S, F> {
    /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
    ///
    /// See also [`TcpStream::set_nodelay`].
    ///
    /// # Example
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .tcp_nodelay(true)
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    pub fn tcp_nodelay(self, nodelay: bool) -> Self {
        Self {
            tcp_nodelay: Some(nodelay),
            ..self
        }
    }

    /// Returns the local address this server is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.tcp_listener.local_addr()
    }
}
321
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
    M: Debug,
    F: Debug,
{
    /// Prints every field except the `PhantomData` marker.
    ///
    /// `S` is never printed (it only appears in `PhantomData`), so, as with `Serve`'s `Debug`
    /// impl, it carries no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring turns a forgotten new field into a compile error.
        let Self {
            tcp_listener,
            make_service,
            signal,
            tcp_nodelay,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("tcp_listener", tcp_listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .field("tcp_nodelay", tcp_nodelay)
            .finish()
    }
}
346
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Runs the accept loop until `signal` completes, then waits for open connections to finish.
    ///
    /// Shutdown is coordinated through two `watch` channels. In both, "all receivers dropped"
    /// serves as the notification:
    ///
    /// - `signal_tx`/`signal_rx`: a helper task awaits `signal` and then drops `signal_rx`.
    ///   That makes `signal_tx.closed()` resolve. The accept loop uses it to stop accepting,
    ///   and each connection task uses it to start `graceful_shutdown`.
    /// - `close_tx`/`close_rx`: every connection task holds a clone of `close_rx` and drops
    ///   it when its connection ends. `close_tx.closed()` therefore resolves once all
    ///   connections are done.
    ///
    /// All setup, including spawning the signal task, happens on the first poll of the
    /// returned future. As a result, `into_future` itself does not need a Tokio runtime, and a
    /// future that is never polled does not leave a task waiting on `signal`.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            tcp_listener,
            mut make_service,
            signal,
            tcp_nodelay,
            _marker: _,
        } = self;

        private::ServeFuture(Box::pin(async move {
            let (signal_tx, signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            let (close_tx, close_rx) = watch::channel(());

            loop {
                let (tcp_stream, remote_addr) = tokio::select! {
                    conn = tcp_accept(&tcp_listener) => {
                        match conn {
                            Some(conn) => conn,
                            None => continue,
                        }
                    }
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };

                if let Some(nodelay) = tcp_nodelay {
                    if let Err(err) = tcp_stream.set_nodelay(nodelay) {
                        trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                    }
                }

                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                // `Error = Infallible`, so the empty match proves these can never fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                let signal_tx = Arc::clone(&signal_tx);

                // Held by the connection task; dropping it tells the outer future this
                // connection has finished.
                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
                    pin_mut!(conn);

                    // Fused, so once it has fired it stays pending on later loop iterations.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                // Keep polling `conn` afterwards so in-flight work can finish.
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    trace!("connection {remote_addr} closed");

                    drop(close_rx);
                });
            }

            // Drop our own receiver so that only connection tasks keep the channel open, and
            // stop listening for new connections.
            drop(close_rx);
            drop(tcp_listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
464
/// Reports whether `e` is an accept error that only affects a single connection.
///
/// Refused, aborted and reset connections are tied to one peer, so the accept loop can skip
/// them immediately instead of backing off.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
473
474async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
475    match listener.accept().await {
476        Ok(conn) => Some(conn),
477        Err(e) => {
478            if is_connection_error(&e) {
479                return None;
480            }
481
482            // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
483            //
484            // > A possible scenario is that the process has hit the max open files
485            // > allowed, and so trying to accept a new connection will fail with
486            // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
487            // > the application will likely close some files (or connections), and try
488            // > to accept the connection again. If this option is `true`, the error
489            // > will be logged at the `error` level, since it is still a big deal,
490            // > and then the listener will sleep for 1 second.
491            //
492            // hyper allowed customizing this but axum does not.
493            error!("accept error: {e}");
494            tokio::time::sleep(Duration::from_secs(1)).await;
495            None
496        }
497    }
498}
499
500mod private {
501    use std::{
502        future::Future,
503        io,
504        pin::Pin,
505        task::{Context, Poll},
506    };
507
508    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
509
510    impl Future for ServeFuture {
511        type Output = io::Result<()>;
512
513        #[inline]
514        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
515            self.0.as_mut().poll(cx)
516        }
517    }
518
519    impl std::fmt::Debug for ServeFuture {
520        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521            f.debug_struct("ServeFuture").finish_non_exhaustive()
522        }
523    }
524}
525
526/// An incoming stream.
527///
528/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
529///
530/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
531#[derive(Debug)]
532pub struct IncomingStream<'a> {
533    tcp_stream: &'a TokioIo<TcpStream>,
534    remote_addr: SocketAddr,
535}
536
537impl IncomingStream<'_> {
538    /// Returns the local address that this stream is bound to.
539    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
540        self.tcp_stream.inner().local_addr()
541    }
542
543    /// Returns the remote address that this stream is bound to.
544    pub fn remote_addr(&self) -> SocketAddr {
545        self.remote_addr
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        Router,
    };
    use std::{
        future::pending,
        net::{IpAddr, Ipv4Addr},
    };

    /// Every test binds to all IPv4 interfaces on an OS-chosen port.
    const ADDR: &str = "0.0.0.0:0";

    /// Binds a fresh listener on [`ADDR`].
    async fn bind() -> TcpListener {
        TcpListener::bind(ADDR).await.unwrap()
    }

    /// Asserts that `addr` is the unspecified IPv4 address with a concrete (non-zero) port.
    fn assert_unspecified_with_real_port(addr: SocketAddr) {
        assert_eq!(addr.ip(), IpAddr::V4(Ipv4Addr::UNSPECIFIED));
        assert_ne!(addr.port(), 0);
    }

    // Never executed: this only checks that `serve` accepts each kind of service axum produces.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        let router: Router = Router::new();

        // router
        serve(bind().await, router.clone());
        serve(bind().await, router.clone().into_make_service());
        serve(
            bind().await,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        );

        // method router
        serve(bind().await, get(handler));
        serve(bind().await, get(handler).into_make_service());
        serve(
            bind().await,
            get(handler).into_make_service_with_connect_info::<SocketAddr>(),
        );

        // handler
        serve(bind().await, handler.into_service());
        serve(bind().await, handler.with_state(()));
        serve(bind().await, handler.into_make_service());
        serve(
            bind().await,
            handler.into_make_service_with_connect_info::<SocketAddr>(),
        );

        // nodelay
        serve(bind().await, handler.into_service()).tcp_nodelay(true);

        serve(bind().await, handler.into_service())
            .with_graceful_shutdown(async { /*...*/ })
            .tcp_nodelay(true);
    }

    async fn handler() {}

    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();

        let server = serve(bind().await, router);

        assert_unspecified_with_real_port(server.local_addr().unwrap());
    }

    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();

        let server = serve(bind().await, router).with_graceful_shutdown(pending());

        assert_unspecified_with_real_port(server.local_addr().unwrap());
    }
}