tonic/transport/service/
tls.rs

1use std::{fmt, io::Cursor};
2
3use tokio_rustls::rustls::{
4    pki_types::{CertificateDer, PrivateKeyDer},
5    RootCertStore,
6};
7
8use crate::transport::Identity;
9
/// ALPN protocol identifier for HTTP/2 (`h2`), kept as a raw byte string
/// for use with rustls.
pub(crate) const ALPN_H2: &[u8] = b"h2";
12
/// Errors produced by the TLS helpers in this module.
#[derive(Debug)]
pub(crate) enum TlsError {
    /// HTTP/2 was not negotiated. Only compiled with the `channel` feature.
    #[cfg(feature = "channel")]
    H2NotNegotiated,
    /// No native certificates were found. Only compiled with the
    /// `tls-native-roots` feature.
    #[cfg(feature = "tls-native-roots")]
    NativeCertsNotFound,
    /// A TLS certificate could not be parsed or was rejected.
    CertificateParseError,
    /// No usable private key could be parsed from the supplied key material.
    PrivateKeyParseError,
}
22
23impl fmt::Display for TlsError {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            #[cfg(feature = "channel")]
27            TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
28            #[cfg(feature = "tls-native-roots")]
29            TlsError::NativeCertsNotFound => write!(f, "no native certs found"),
30            TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
31            TlsError::PrivateKeyParseError => write!(
32                f,
33                "Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
34            ),
35        }
36    }
37}
38
39impl std::error::Error for TlsError {}
40
41pub(crate) fn load_identity(
42    identity: Identity,
43) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>), TlsError> {
44    let cert = rustls_pemfile::certs(&mut Cursor::new(identity.cert))
45        .collect::<Result<Vec<_>, _>>()
46        .map_err(|_| TlsError::CertificateParseError)?;
47
48    let Ok(Some(key)) = rustls_pemfile::private_key(&mut Cursor::new(identity.key)) else {
49        return Err(TlsError::PrivateKeyParseError);
50    };
51
52    Ok((cert, key))
53}
54
55pub(crate) fn add_certs_from_pem(
56    mut certs: &mut dyn std::io::BufRead,
57    roots: &mut RootCertStore,
58) -> Result<(), crate::Error> {
59    for cert in rustls_pemfile::certs(&mut certs).collect::<Result<Vec<_>, _>>()? {
60        roots
61            .add(cert)
62            .map_err(|_| TlsError::CertificateParseError)?;
63    }
64
65    Ok(())
66}