1use std::{
4 convert::Infallible,
5 fmt::Debug,
6 future::{poll_fn, Future, IntoFuture},
7 io,
8 marker::PhantomData,
9 net::SocketAddr,
10 sync::Arc,
11 time::Duration,
12};
13
14use axum_core::{body::Body, extract::Request, response::Response};
15use futures_util::{pin_mut, FutureExt};
16use hyper::body::Incoming;
17use hyper_util::rt::{TokioExecutor, TokioIo};
18#[cfg(any(feature = "http1", feature = "http2"))]
19use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
20use tokio::{
21 net::{TcpListener, TcpStream},
22 sync::watch,
23};
24use tower::ServiceExt as _;
25use tower_service::Service;
26
/// Serve the service with the supplied listener.
///
/// Returns a [`Serve`] value; awaiting it (via [`IntoFuture`]) runs an accept
/// loop on `tcp_listener`, asking `make_service` for a fresh service for every
/// accepted connection.
///
/// `TCP_NODELAY` is left at the OS default unless [`Serve::tcp_nodelay`] is
/// called.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `None` means "do not touch the socket option"; only an explicit
    // `tcp_nodelay(..)` call turns it into `Some(_)`.
    let tcp_nodelay = None;

    Serve {
        make_service,
        tcp_listener,
        tcp_nodelay,
        _marker: PhantomData,
    }
}
105
/// Value returned by [`serve`]; implements [`IntoFuture`] so it can be
/// `.await`ed directly.
///
/// Configure it with [`Serve::tcp_nodelay`] or
/// [`Serve::with_graceful_shutdown`] before awaiting it.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
    tcp_listener: TcpListener,
    make_service: M,
    /// `Some(v)` applies `TCP_NODELAY = v` to every accepted socket; `None`
    /// leaves the OS default untouched.
    tcp_nodelay: Option<bool>,
    _marker: PhantomData<S>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
    /// Prepares the server to shut down gracefully once `signal` completes.
    ///
    /// All other settings (listener, make-service, `TCP_NODELAY`) carry over.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            tcp_listener,
            make_service,
            tcp_nodelay,
            _marker: _,
        } = self;

        WithGracefulShutdown {
            tcp_listener,
            make_service,
            signal,
            tcp_nodelay,
            _marker: PhantomData,
        }
    }

    /// Instructs the server to set `TCP_NODELAY` to `nodelay` on every
    /// accepted connection.
    pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
        self.tcp_nodelay = Some(nodelay);
        self
    }

    /// Returns the local address the underlying listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let listener = &self.tcp_listener;
        listener.local_addr()
    }
}
182
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
    M: Debug,
{
    /// Formats the listener, the make-service and the `TCP_NODELAY` setting.
    ///
    /// The `PhantomData` marker is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Serve` becomes a compile
        // error here until the field is either printed or explicitly ignored.
        let Serve {
            tcp_listener: listener,
            make_service: maker,
            tcp_nodelay: nodelay,
            _marker: _,
        } = self;

        let mut out = f.debug_struct("Serve");
        out.field("tcp_listener", listener);
        out.field("make_service", maker);
        out.field("tcp_nodelay", nodelay);
        out.finish()
    }
}
203
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the accept loop.
    ///
    /// The returned future never completes on its own. It accepts connections
    /// until it is dropped and spawns one Tokio task per connection. When
    /// `tcp_accept` yields `None` (a failed accept), the loop simply tries again.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                tcp_nodelay,
                _marker: _,
            } = self;

            loop {
                let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
                    Some(conn) => conn,
                    None => continue,
                };

                if let Some(nodelay) = tcp_nodelay {
                    if let Err(err) = tcp_stream.set_nodelay(nodelay) {
                        trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                    }
                }

                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                tokio::spawn(async move {
                    // Upgrades are enabled so that protocol upgrades (e.g.
                    // WebSockets) keep working on this connection.
                    let result = Builder::new(TokioExecutor::new())
                        .serve_connection_with_upgrades(tcp_stream, hyper_service)
                        .await;

                    // Report the error instead of discarding it silently,
                    // using the same `trace` level as the graceful-shutdown
                    // server. `_err` is underscored because `trace!` may
                    // expand to nothing when tracing is disabled.
                    if let Err(_err) = result {
                        trace!("failed to serve connection: {_err:#}");
                    }

                    trace!("connection {remote_addr} closed");
                });
            }
        }))
    }
}
273
/// Serve value with graceful shutdown enabled, created by
/// [`Serve::with_graceful_shutdown`].
///
/// Implements [`IntoFuture`] so it can be `.await`ed directly.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
    tcp_listener: TcpListener,
    make_service: M,
    /// Future whose completion triggers the graceful shutdown.
    signal: F,
    /// `Some(v)` applies `TCP_NODELAY = v` to every accepted socket; `None`
    /// leaves the OS default untouched.
    tcp_nodelay: Option<bool>,
    _marker: PhantomData<S>,
}

// Gated exactly like the struct above. Without this `cfg`, the impl would
// refer to a type that does not exist when the features are disabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> WithGracefulShutdown<M, S, F> {
    /// Instructs the server to set `TCP_NODELAY` to `nodelay` on every
    /// accepted connection.
    pub fn tcp_nodelay(self, nodelay: bool) -> Self {
        Self {
            tcp_nodelay: Some(nodelay),
            ..self
        }
    }

    /// Returns the local address the underlying listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.tcp_listener.local_addr()
    }
}
321
// `S` is only held in `PhantomData` and is never formatted, so it needs no
// `Debug` bound. This matches `Debug for Serve` and lets servers whose service
// type is not `Debug` still be printed.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
    M: Debug,
    F: Debug,
{
    /// Formats the listener, make-service, shutdown signal and `TCP_NODELAY`
    /// setting. The `PhantomData` marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring so new fields cannot be forgotten here.
        let Self {
            tcp_listener,
            make_service,
            signal,
            tcp_nodelay,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("tcp_listener", tcp_listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .field("tcp_nodelay", tcp_nodelay)
            .finish()
    }
}
346
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the accept loop with graceful shutdown.
    ///
    /// All setup, including spawning the task that waits on `signal`, happens
    /// inside the returned future. So calling `into_future` does no work and
    /// needs no running Tokio runtime, and dropping an unpolled future leaves
    /// no task behind.
    ///
    /// Once `signal` resolves:
    /// - the loop stops accepting new connections;
    /// - every in-flight connection is asked to shut down gracefully;
    /// - the future resolves with `Ok(())` after all connection tasks finish.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                signal,
                tcp_nodelay,
                _marker: _,
            } = self;

            // Shutdown broadcast: `signal_tx.closed()` resolves once the only
            // receiver, `signal_rx`, is dropped by the task below.
            let (signal_tx, signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            // Completion tracking: every connection task holds a clone of
            // `close_rx`; `close_tx.closed()` resolves once all are dropped.
            let (close_tx, close_rx) = watch::channel(());

            loop {
                let (tcp_stream, remote_addr) = tokio::select! {
                    conn = tcp_accept(&tcp_listener) => {
                        match conn {
                            Some(conn) => conn,
                            None => continue,
                        }
                    }
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };

                if let Some(nodelay) = tcp_nodelay {
                    if let Err(err) = tcp_stream.set_nodelay(nodelay) {
                        trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                    }
                }

                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));

                let hyper_service = TowerToHyperService::new(tower_service);

                let signal_tx = Arc::clone(&signal_tx);

                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
                    pin_mut!(conn);

                    // Fused so that, once it has fired, polling it again in the
                    // loop below just stays pending.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    trace!("connection {remote_addr} closed");

                    drop(close_rx);
                });
            }

            // Drop our own receiver so only connection tasks keep the channel
            // open, and release the port before waiting.
            drop(close_rx);
            drop(tcp_listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
464
/// Returns `true` for errors that concern only the single connection being
/// accepted: refused, aborted, or reset.
///
/// `tcp_accept` treats these as immediately retryable (no back-off), since
/// the listener itself is still usable.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION.contains(&e.kind())
}
473
474async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
475 match listener.accept().await {
476 Ok(conn) => Some(conn),
477 Err(e) => {
478 if is_connection_error(&e) {
479 return None;
480 }
481
482 error!("accept error: {e}");
494 tokio::time::sleep(Duration::from_secs(1)).await;
495 None
496 }
497 }
498}
499
500mod private {
501 use std::{
502 future::Future,
503 io,
504 pin::Pin,
505 task::{Context, Poll},
506 };
507
508 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
509
510 impl Future for ServeFuture {
511 type Output = io::Result<()>;
512
513 #[inline]
514 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
515 self.0.as_mut().poll(cx)
516 }
517 }
518
519 impl std::fmt::Debug for ServeFuture {
520 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521 f.debug_struct("ServeFuture").finish_non_exhaustive()
522 }
523 }
524}
525
526#[derive(Debug)]
532pub struct IncomingStream<'a> {
533 tcp_stream: &'a TokioIo<TcpStream>,
534 remote_addr: SocketAddr,
535}
536
537impl IncomingStream<'_> {
538 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
540 self.tcp_stream.inner().local_addr()
541 }
542
543 pub fn remote_addr(&self) -> SocketAddr {
545 self.remote_addr
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        Router,
    };
    use std::{
        future::pending,
        net::{IpAddr, Ipv4Addr},
    };

    /// Address every test binds to: all IPv4 interfaces, OS-chosen port.
    const ADDR: &str = "0.0.0.0:0";

    /// Binds a fresh listener on [`ADDR`], panicking on failure.
    async fn bind() -> TcpListener {
        TcpListener::bind(ADDR).await.unwrap()
    }

    /// Asserts that `address` is the unspecified IPv4 address with a port
    /// actually assigned by the OS.
    fn assert_bound_to_ephemeral_port(address: SocketAddr) {
        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Compile-only check that all the common service shapes are accepted by
    // `serve`. The returned values are deliberately never awaited, hence the
    // `unused_must_use` allowance; the function itself is never called.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        let router: Router = Router::new();

        serve(bind().await, router.clone());
        serve(bind().await, router.clone().into_make_service());
        serve(
            bind().await,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        );

        serve(bind().await, get(handler));
        serve(bind().await, get(handler).into_make_service());
        serve(
            bind().await,
            get(handler).into_make_service_with_connect_info::<SocketAddr>(),
        );

        serve(bind().await, handler.into_service());
        serve(bind().await, handler.with_state(()));
        serve(bind().await, handler.into_make_service());
        serve(
            bind().await,
            handler.into_make_service_with_connect_info::<SocketAddr>(),
        );

        serve(bind().await, handler.into_service()).tcp_nodelay(true);

        serve(bind().await, handler.into_service())
            .with_graceful_shutdown(async {})
            .tcp_nodelay(true);
    }

    async fn handler() {}

    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();

        let server = serve(bind().await, router.clone());

        assert_bound_to_ephemeral_port(server.local_addr().unwrap());
    }

    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();

        let server = serve(bind().await, router.clone()).with_graceful_shutdown(pending());

        assert_bound_to_ephemeral_port(server.local_addr().unwrap());
    }
}