tonic/transport/channel/service/tls.rs

1use std::fmt;
2use std::io::Cursor;
3use std::sync::Arc;
4
5use hyper_util::rt::TokioIo;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio_rustls::{
8    rustls::{
9        pki_types::{ServerName, TrustAnchor},
10        ClientConfig, RootCertStore,
11    },
12    TlsConnector as RustlsConnector,
13};
14
15use super::io::BoxedIo;
16use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2};
17use crate::transport::tls::{Certificate, Identity};
18
19#[derive(Clone)]
20pub(crate) struct TlsConnector {
21    config: Arc<ClientConfig>,
22    domain: Arc<ServerName<'static>>,
23    assume_http2: bool,
24}
25
26impl TlsConnector {
27    pub(crate) fn new(
28        ca_certs: Vec<Certificate>,
29        trust_anchors: Vec<TrustAnchor<'static>>,
30        identity: Option<Identity>,
31        domain: &str,
32        assume_http2: bool,
33        #[cfg(feature = "tls-native-roots")] with_native_roots: bool,
34        #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
35    ) -> Result<Self, crate::Error> {
36        let builder = ClientConfig::builder();
37        let mut roots = RootCertStore::from_iter(trust_anchors);
38
39        #[cfg(feature = "tls-native-roots")]
40        if with_native_roots {
41            let rustls_native_certs::CertificateResult { certs, errors, .. } =
42                rustls_native_certs::load_native_certs();
43            if !errors.is_empty() {
44                tracing::debug!("errors occured when loading native certs: {errors:?}");
45            }
46            if certs.is_empty() {
47                return Err(TlsError::NativeCertsNotFound.into());
48            }
49            roots.add_parsable_certificates(certs);
50        }
51
52        #[cfg(feature = "tls-webpki-roots")]
53        if with_webpki_roots {
54            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
55        }
56
57        for cert in ca_certs {
58            add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
59        }
60
61        let builder = builder.with_root_certificates(roots);
62        let mut config = match identity {
63            Some(identity) => {
64                let (client_cert, client_key) = load_identity(identity)?;
65                builder.with_client_auth_cert(client_cert, client_key)?
66            }
67            None => builder.with_no_client_auth(),
68        };
69
70        config.alpn_protocols.push(ALPN_H2.into());
71        Ok(Self {
72            config: Arc::new(config),
73            domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
74            assume_http2,
75        })
76    }
77
78    pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
79    where
80        I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
81    {
82        let io = RustlsConnector::from(self.config.clone())
83            .connect(self.domain.as_ref().to_owned(), io)
84            .await?;
85
86        // Generally we require ALPN to be negotiated, but if the user has
87        // explicitly set `assume_http2` to true, we'll allow it to be missing.
88        let (_, session) = io.get_ref();
89        let alpn_protocol = session.alpn_protocol();
90        if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
91            return Err(TlsError::H2NotNegotiated.into());
92        }
93        Ok(BoxedIo::new(TokioIo::new(io)))
94    }
95}
96
97impl fmt::Debug for TlsConnector {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        f.debug_struct("TlsConnector").finish()
100    }
101}