tonic/transport/channel/service/
tls.rs1use std::fmt;
2use std::io::Cursor;
3use std::sync::Arc;
4
5use hyper_util::rt::TokioIo;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio_rustls::{
8 rustls::{
9 pki_types::{ServerName, TrustAnchor},
10 ClientConfig, RootCertStore,
11 },
12 TlsConnector as RustlsConnector,
13};
14
15use super::io::BoxedIo;
16use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2};
17use crate::transport::tls::{Certificate, Identity};
18
19#[derive(Clone)]
20pub(crate) struct TlsConnector {
21 config: Arc<ClientConfig>,
22 domain: Arc<ServerName<'static>>,
23 assume_http2: bool,
24}
25
26impl TlsConnector {
27 pub(crate) fn new(
28 ca_certs: Vec<Certificate>,
29 trust_anchors: Vec<TrustAnchor<'static>>,
30 identity: Option<Identity>,
31 domain: &str,
32 assume_http2: bool,
33 #[cfg(feature = "tls-native-roots")] with_native_roots: bool,
34 #[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
35 ) -> Result<Self, crate::Error> {
36 let builder = ClientConfig::builder();
37 let mut roots = RootCertStore::from_iter(trust_anchors);
38
39 #[cfg(feature = "tls-native-roots")]
40 if with_native_roots {
41 let rustls_native_certs::CertificateResult { certs, errors, .. } =
42 rustls_native_certs::load_native_certs();
43 if !errors.is_empty() {
44 tracing::debug!("errors occured when loading native certs: {errors:?}");
45 }
46 if certs.is_empty() {
47 return Err(TlsError::NativeCertsNotFound.into());
48 }
49 roots.add_parsable_certificates(certs);
50 }
51
52 #[cfg(feature = "tls-webpki-roots")]
53 if with_webpki_roots {
54 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
55 }
56
57 for cert in ca_certs {
58 add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
59 }
60
61 let builder = builder.with_root_certificates(roots);
62 let mut config = match identity {
63 Some(identity) => {
64 let (client_cert, client_key) = load_identity(identity)?;
65 builder.with_client_auth_cert(client_cert, client_key)?
66 }
67 None => builder.with_no_client_auth(),
68 };
69
70 config.alpn_protocols.push(ALPN_H2.into());
71 Ok(Self {
72 config: Arc::new(config),
73 domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
74 assume_http2,
75 })
76 }
77
78 pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
79 where
80 I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
81 {
82 let io = RustlsConnector::from(self.config.clone())
83 .connect(self.domain.as_ref().to_owned(), io)
84 .await?;
85
86 let (_, session) = io.get_ref();
89 let alpn_protocol = session.alpn_protocol();
90 if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
91 return Err(TlsError::H2NotNegotiated.into());
92 }
93 Ok(BoxedIo::new(TokioIo::new(io)))
94 }
95}
96
97impl fmt::Debug for TlsConnector {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 f.debug_struct("TlsConnector").finish()
100 }
101}