gcp_auth/
gcloud_authorized_user.rs

1use std::process::Command;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7use tracing::{debug, instrument};
8
9use crate::types::Token;
10use crate::{Error, TokenProvider};
11
12/// A token provider that queries the `gcloud` CLI for access tokens
13#[derive(Debug)]
14pub struct GCloudAuthorizedUser {
15    project_id: Option<Arc<str>>,
16    token: RwLock<Arc<Token>>,
17}
18
19impl GCloudAuthorizedUser {
20    /// Check if `gcloud` is installed and logged in
21    pub async fn new() -> Result<Self, Error> {
22        debug!("try to print access token via `gcloud`");
23        let token = RwLock::new(Self::fetch_token()?);
24        let project_id = run(&["config", "get-value", "project"]).ok();
25        Ok(Self {
26            project_id: project_id.map(Arc::from),
27            token,
28        })
29    }
30
31    #[instrument(level = tracing::Level::DEBUG)]
32    fn fetch_token() -> Result<Arc<Token>, Error> {
33        Ok(Arc::new(Token::from_string(
34            run(&["auth", "print-access-token", "--quiet"])?,
35            DEFAULT_TOKEN_DURATION,
36        )))
37    }
38}
39
40#[async_trait]
41impl TokenProvider for GCloudAuthorizedUser {
42    async fn token(&self, _scopes: &[&str]) -> Result<Arc<Token>, Error> {
43        let token = self.token.read().await.clone();
44        if !token.has_expired() {
45            return Ok(token);
46        }
47
48        let mut locked = self.token.write().await;
49        let token = Self::fetch_token()?;
50        *locked = token.clone();
51        Ok(token)
52    }
53
54    async fn project_id(&self) -> Result<Arc<str>, Error> {
55        self.project_id
56            .clone()
57            .ok_or(Error::Str("failed to get project ID from `gcloud`"))
58    }
59}
60
61fn run(cmd: &[&str]) -> Result<String, Error> {
62    let mut command = Command::new(GCLOUD_CMD);
63    command.args(cmd);
64
65    let mut stdout = match command.output() {
66        Ok(output) if output.status.success() => output.stdout,
67        Ok(_) => return Err(Error::Str("running `gcloud` command failed")),
68        Err(err) => return Err(Error::Io("failed to run `gcloud`", err)),
69    };
70
71    while let Some(b' ' | b'\r' | b'\n') = stdout.last() {
72        stdout.pop();
73    }
74
75    String::from_utf8(stdout).map_err(|_| Error::Str("output from `gcloud` is not UTF-8"))
76}
77
/// Name of the `gcloud` executable on non-Windows platforms.
///
/// Every non-Windows target gets this value so the crate still builds on Unix-likes other than
/// Linux and macOS (for example the BSDs).
#[cfg(not(target_os = "windows"))]
const GCLOUD_CMD: &str = "gcloud";

/// Name of the `gcloud` launcher on Windows, where the CLI is installed as a `.cmd` script.
#[cfg(target_os = "windows")]
const GCLOUD_CMD: &str = "gcloud.cmd";
83
/// Lifetime assumed for tokens printed by `gcloud auth print-access-token`.
///
/// That command prints only the token, with no expiry information, so a fetched token is
/// treated as valid for this long from the moment it was fetched. One hour matches what we
/// have observed in practice, but we have not found documentation guaranteeing that it is
/// always the lifetime.
pub(crate) const DEFAULT_TOKEN_DURATION: Duration = Duration::from_secs(60 * 60);
88
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;

    /// Tolerance used when comparing a token's expiry against a freshly computed deadline.
    const SLACK: Duration = Duration::from_secs(1);

    /// Live check against a locally installed, logged-in `gcloud`; ignored by default.
    #[tokio::test]
    #[ignore]
    async fn gcloud() {
        let provider = GCloudAuthorizedUser::new().await.unwrap();
        println!("{:?}", provider.project_id);

        let fetched = provider
            .token(&[""])
            .await
            .unwrap_or_else(|_| panic!("GCloud Authorized User failed to get a token"));
        let deadline = Utc::now() + DEFAULT_TOKEN_DURATION;
        println!("{:?}", fetched);
        assert!(!fetched.has_expired());
        assert!(fetched.expires_at() < deadline + SLACK);
        assert!(fetched.expires_at() > deadline - SLACK);
    }

    /// Tokens from `gcloud` are the only ones built from a bare string rather than JSON, and
    /// they carry no expiry of their own, so the default-lifetime behaviour is covered here.
    #[test]
    fn test_token_from_string() {
        let raw = String::from("abc123");
        let built = Token::from_string(raw, DEFAULT_TOKEN_DURATION);
        let deadline = Utc::now() + DEFAULT_TOKEN_DURATION;

        assert_eq!(built.as_str(), "abc123");
        assert!(!built.has_expired());
        assert!(built.expires_at() < deadline + SLACK);
        assert!(built.expires_at() > deadline - SLACK);
    }

    /// A JSON token without an expiry must be rejected as a data error.
    #[test]
    fn test_deserialize_no_time() {
        let json = r#"{"access_token":"abc123"}"#;
        let err = serde_json::from_str::<Token>(json)
            .expect_err("Deserialization from JSON should fail when no expiry_time is included");

        assert!(err.is_data());
    }
}