// tonic/transport/server/service/tls.rs

1use std::{fmt, io::Cursor, sync::Arc};
2
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio_rustls::{
5    rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
6    server::TlsStream,
7    TlsAcceptor as RustlsAcceptor,
8};
9
10use crate::transport::{
11    service::tls::{add_certs_from_pem, load_identity, ALPN_H2},
12    Certificate, Identity,
13};
14
15#[derive(Clone)]
16pub(crate) struct TlsAcceptor {
17    inner: Arc<ServerConfig>,
18}
19
20impl TlsAcceptor {
21    pub(crate) fn new(
22        identity: Identity,
23        client_ca_root: Option<Certificate>,
24        client_auth_optional: bool,
25    ) -> Result<Self, crate::Error> {
26        let builder = ServerConfig::builder();
27
28        let builder = match client_ca_root {
29            None => builder.with_no_client_auth(),
30            Some(cert) => {
31                let mut roots = RootCertStore::empty();
32                add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
33                let verifier = if client_auth_optional {
34                    WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
35                } else {
36                    WebPkiClientVerifier::builder(roots.into())
37                }
38                .build()?;
39                builder.with_client_cert_verifier(verifier)
40            }
41        };
42
43        let (cert, key) = load_identity(identity)?;
44        let mut config = builder.with_single_cert(cert, key)?;
45
46        config.alpn_protocols.push(ALPN_H2.into());
47        Ok(Self {
48            inner: Arc::new(config),
49        })
50    }
51
52    pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
53    where
54        IO: AsyncRead + AsyncWrite + Unpin,
55    {
56        let acceptor = RustlsAcceptor::from(self.inner.clone());
57        acceptor.accept(io).await.map_err(Into::into)
58    }
59}
60
61impl fmt::Debug for TlsAcceptor {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("TlsAcceptor").finish()
64    }
65}