gcp_auth/
custom_service_account.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::str::FromStr;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use base64::{engine::general_purpose::URL_SAFE, Engine};
8use bytes::Bytes;
9use chrono::Utc;
10use http_body_util::Full;
11use hyper::header::CONTENT_TYPE;
12use hyper::Request;
13use serde::Serialize;
14use tokio::sync::RwLock;
15use tracing::{debug, instrument, Level};
16use url::form_urlencoded;
17
18use crate::types::{HttpClient, ServiceAccountKey, Signer, Token};
19use crate::{Error, TokenProvider};
20
21/// A custom service account containing credentials
22///
23/// Once initialized, a [`CustomServiceAccount`] can be converted into an [`AuthenticationManager`]
24/// using the applicable `From` implementation.
25///
26/// [`AuthenticationManager`]: crate::AuthenticationManager
27#[derive(Debug)]
28pub struct CustomServiceAccount {
29    client: HttpClient,
30    credentials: ServiceAccountKey,
31    signer: Signer,
32    tokens: RwLock<HashMap<Vec<String>, Arc<Token>>>,
33    subject: Option<String>,
34    audience: Option<String>,
35}
36
37impl CustomServiceAccount {
38    /// Check `GOOGLE_APPLICATION_CREDENTIALS` environment variable for a path to JSON credentials
39    pub fn from_env() -> Result<Option<Self>, Error> {
40        debug!("check for GOOGLE_APPLICATION_CREDENTIALS env var");
41        match ServiceAccountKey::from_env()? {
42            Some(credentials) => Self::new(credentials, HttpClient::new()?).map(Some),
43            None => Ok(None),
44        }
45    }
46
47    /// Read service account credentials from the given JSON file
48    pub fn from_file<T: AsRef<Path>>(path: T) -> Result<Self, Error> {
49        Self::new(ServiceAccountKey::from_file(path)?, HttpClient::new()?)
50    }
51
52    /// Read service account credentials from the given JSON string
53    pub fn from_json(s: &str) -> Result<Self, Error> {
54        Self::new(ServiceAccountKey::from_str(s)?, HttpClient::new()?)
55    }
56
57    /// Set the `subject` to impersonate a user
58    pub fn with_subject(mut self, subject: String) -> Self {
59        self.subject = Some(subject);
60        self
61    }
62
63    /// Set the `Audience` to impersonate a user
64    pub fn with_audience(mut self, audience: String) -> Self {
65        self.audience = Some(audience);
66        self
67    }
68
69    fn new(credentials: ServiceAccountKey, client: HttpClient) -> Result<Self, Error> {
70        debug!(project = ?credentials.project_id, email = credentials.client_email, "found credentials");
71        Ok(Self {
72            client,
73            signer: Signer::new(&credentials.private_key)?,
74            credentials,
75            tokens: RwLock::new(HashMap::new()),
76            subject: None,
77            audience: None,
78        })
79    }
80
81    #[instrument(level = Level::DEBUG, skip(self))]
82    async fn fetch_token(&self, scopes: &[&str]) -> Result<Arc<Token>, Error> {
83        let jwt = Claims::new(
84            &self.credentials,
85            scopes,
86            self.subject.as_deref(),
87            self.audience.as_deref(),
88        )
89        .to_jwt(&self.signer)?;
90        let body = Bytes::from(
91            form_urlencoded::Serializer::new(String::new())
92                .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
93                .finish()
94                .into_bytes(),
95        );
96
97        let token = self
98            .client
99            .token(
100                &|| {
101                    Request::post(&self.credentials.token_uri)
102                        .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
103                        .body(Full::from(body.clone()))
104                        .unwrap()
105                },
106                "CustomServiceAccount",
107            )
108            .await?;
109
110        Ok(token)
111    }
112
113    /// The RSA PKCS1 SHA256 [`Signer`] used to sign JWT tokens
114    pub fn signer(&self) -> &Signer {
115        &self.signer
116    }
117
118    /// The project ID as found in the credentials
119    pub fn project_id(&self) -> Option<&str> {
120        self.credentials.project_id.as_deref()
121    }
122
123    /// The private key as found in the credentials
124    pub fn private_key_pem(&self) -> &str {
125        &self.credentials.private_key
126    }
127}
128
129#[async_trait]
130impl TokenProvider for CustomServiceAccount {
131    async fn token(&self, scopes: &[&str]) -> Result<Arc<Token>, Error> {
132        let key: Vec<_> = scopes.iter().map(|x| x.to_string()).collect();
133        let token = self.tokens.read().await.get(&key).cloned();
134        if let Some(token) = token {
135            if !token.has_expired() {
136                return Ok(token.clone());
137            }
138
139            let mut locked = self.tokens.write().await;
140            let token = self.fetch_token(scopes).await?;
141            locked.insert(key, token.clone());
142            return Ok(token);
143        }
144
145        let mut locked = self.tokens.write().await;
146        let token = self.fetch_token(scopes).await?;
147        locked.insert(key, token.clone());
148        return Ok(token);
149    }
150
151    async fn project_id(&self) -> Result<Arc<str>, Error> {
152        match &self.credentials.project_id {
153            Some(pid) => Ok(pid.clone()),
154            None => Err(Error::Str("no project ID in application credentials")),
155        }
156    }
157}
158
159/// Permissions requested for a JWT.
160/// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount#authorizingrequests.
161#[derive(Serialize, Debug)]
162pub(crate) struct Claims<'a> {
163    iss: &'a str,
164    aud: &'a str,
165    exp: i64,
166    iat: i64,
167    sub: Option<&'a str>,
168    scope: String,
169}
170
171impl<'a> Claims<'a> {
172    pub(crate) fn new(
173        key: &'a ServiceAccountKey,
174        scopes: &[&str],
175        sub: Option<&'a str>,
176        aud: Option<&'a str>,
177    ) -> Self {
178        let mut scope = String::with_capacity(16);
179        for (i, s) in scopes.iter().enumerate() {
180            if i != 0 {
181                scope.push(' ');
182            }
183
184            scope.push_str(s);
185        }
186
187        let iat = Utc::now().timestamp();
188        Claims {
189            iss: &key.client_email,
190            aud: aud.unwrap_or(&key.token_uri),
191            exp: iat + 3600 - 5, // Max validity is 1h
192            iat,
193            sub,
194            scope,
195        }
196    }
197
198    pub(crate) fn to_jwt(&self, signer: &Signer) -> Result<String, Error> {
199        let mut jwt = String::new();
200        URL_SAFE.encode_string(GOOGLE_RS256_HEAD, &mut jwt);
201        jwt.push('.');
202        URL_SAFE.encode_string(serde_json::to_string(self).unwrap(), &mut jwt);
203
204        let signature = signer.sign(jwt.as_bytes())?;
205        jwt.push('.');
206        URL_SAFE.encode_string(&signature, &mut jwt);
207        Ok(jwt)
208    }
209}
210
/// Value of the `grant_type` form field sent when exchanging a signed JWT for a token
pub(crate) const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";

/// JSON header placed in front of every JWT: RS256 algorithm, `JWT` type
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;