gcp_auth/
custom_service_account.rs1use std::collections::HashMap;
2use std::path::Path;
3use std::str::FromStr;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use base64::{engine::general_purpose::URL_SAFE, Engine};
8use bytes::Bytes;
9use chrono::Utc;
10use http_body_util::Full;
11use hyper::header::CONTENT_TYPE;
12use hyper::Request;
13use serde::Serialize;
14use tokio::sync::RwLock;
15use tracing::{debug, instrument, Level};
16use url::form_urlencoded;
17
18use crate::types::{HttpClient, ServiceAccountKey, Signer, Token};
19use crate::{Error, TokenProvider};
20
21#[derive(Debug)]
28pub struct CustomServiceAccount {
29 client: HttpClient,
30 credentials: ServiceAccountKey,
31 signer: Signer,
32 tokens: RwLock<HashMap<Vec<String>, Arc<Token>>>,
33 subject: Option<String>,
34 audience: Option<String>,
35}
36
37impl CustomServiceAccount {
38 pub fn from_env() -> Result<Option<Self>, Error> {
40 debug!("check for GOOGLE_APPLICATION_CREDENTIALS env var");
41 match ServiceAccountKey::from_env()? {
42 Some(credentials) => Self::new(credentials, HttpClient::new()?).map(Some),
43 None => Ok(None),
44 }
45 }
46
47 pub fn from_file<T: AsRef<Path>>(path: T) -> Result<Self, Error> {
49 Self::new(ServiceAccountKey::from_file(path)?, HttpClient::new()?)
50 }
51
52 pub fn from_json(s: &str) -> Result<Self, Error> {
54 Self::new(ServiceAccountKey::from_str(s)?, HttpClient::new()?)
55 }
56
57 pub fn with_subject(mut self, subject: String) -> Self {
59 self.subject = Some(subject);
60 self
61 }
62
63 pub fn with_audience(mut self, audience: String) -> Self {
65 self.audience = Some(audience);
66 self
67 }
68
69 fn new(credentials: ServiceAccountKey, client: HttpClient) -> Result<Self, Error> {
70 debug!(project = ?credentials.project_id, email = credentials.client_email, "found credentials");
71 Ok(Self {
72 client,
73 signer: Signer::new(&credentials.private_key)?,
74 credentials,
75 tokens: RwLock::new(HashMap::new()),
76 subject: None,
77 audience: None,
78 })
79 }
80
81 #[instrument(level = Level::DEBUG, skip(self))]
82 async fn fetch_token(&self, scopes: &[&str]) -> Result<Arc<Token>, Error> {
83 let jwt = Claims::new(
84 &self.credentials,
85 scopes,
86 self.subject.as_deref(),
87 self.audience.as_deref(),
88 )
89 .to_jwt(&self.signer)?;
90 let body = Bytes::from(
91 form_urlencoded::Serializer::new(String::new())
92 .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
93 .finish()
94 .into_bytes(),
95 );
96
97 let token = self
98 .client
99 .token(
100 &|| {
101 Request::post(&self.credentials.token_uri)
102 .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
103 .body(Full::from(body.clone()))
104 .unwrap()
105 },
106 "CustomServiceAccount",
107 )
108 .await?;
109
110 Ok(token)
111 }
112
113 pub fn signer(&self) -> &Signer {
115 &self.signer
116 }
117
118 pub fn project_id(&self) -> Option<&str> {
120 self.credentials.project_id.as_deref()
121 }
122
123 pub fn private_key_pem(&self) -> &str {
125 &self.credentials.private_key
126 }
127}
128
129#[async_trait]
130impl TokenProvider for CustomServiceAccount {
131 async fn token(&self, scopes: &[&str]) -> Result<Arc<Token>, Error> {
132 let key: Vec<_> = scopes.iter().map(|x| x.to_string()).collect();
133 let token = self.tokens.read().await.get(&key).cloned();
134 if let Some(token) = token {
135 if !token.has_expired() {
136 return Ok(token.clone());
137 }
138
139 let mut locked = self.tokens.write().await;
140 let token = self.fetch_token(scopes).await?;
141 locked.insert(key, token.clone());
142 return Ok(token);
143 }
144
145 let mut locked = self.tokens.write().await;
146 let token = self.fetch_token(scopes).await?;
147 locked.insert(key, token.clone());
148 return Ok(token);
149 }
150
151 async fn project_id(&self) -> Result<Arc<str>, Error> {
152 match &self.credentials.project_id {
153 Some(pid) => Ok(pid.clone()),
154 None => Err(Error::Str("no project ID in application credentials")),
155 }
156 }
157}
158
159#[derive(Serialize, Debug)]
162pub(crate) struct Claims<'a> {
163 iss: &'a str,
164 aud: &'a str,
165 exp: i64,
166 iat: i64,
167 sub: Option<&'a str>,
168 scope: String,
169}
170
171impl<'a> Claims<'a> {
172 pub(crate) fn new(
173 key: &'a ServiceAccountKey,
174 scopes: &[&str],
175 sub: Option<&'a str>,
176 aud: Option<&'a str>,
177 ) -> Self {
178 let mut scope = String::with_capacity(16);
179 for (i, s) in scopes.iter().enumerate() {
180 if i != 0 {
181 scope.push(' ');
182 }
183
184 scope.push_str(s);
185 }
186
187 let iat = Utc::now().timestamp();
188 Claims {
189 iss: &key.client_email,
190 aud: aud.unwrap_or(&key.token_uri),
191 exp: iat + 3600 - 5, iat,
193 sub,
194 scope,
195 }
196 }
197
198 pub(crate) fn to_jwt(&self, signer: &Signer) -> Result<String, Error> {
199 let mut jwt = String::new();
200 URL_SAFE.encode_string(GOOGLE_RS256_HEAD, &mut jwt);
201 jwt.push('.');
202 URL_SAFE.encode_string(serde_json::to_string(self).unwrap(), &mut jwt);
203
204 let signature = signer.sign(jwt.as_bytes())?;
205 jwt.push('.');
206 URL_SAFE.encode_string(&signature, &mut jwt);
207 Ok(jwt)
208 }
209}
210
/// OAuth 2.0 `grant_type` for exchanging a signed JWT assertion for an access token
/// (the JWT bearer grant URN from RFC 7523).
pub(crate) const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";

/// Fixed JOSE header of every assertion signed here: RS256 signature, token type `JWT`.
const GOOGLE_RS256_HEAD: &str = "{\"alg\":\"RS256\",\"typ\":\"JWT\"}";