Skip to content

Commit

Permalink
Cache access token in unftp-sbe-gcs (#435)
Browse files Browse the repository at this point in the history
This commit implements caching for access tokens in the Google Cloud Storage backend. Before this we re-requested the auth token before each GCS REST query.

I'm using an RwLock rather than a Mutex, since the token is read much more often than it's written.

I also changed the name of the workflow_identity module to workload_identity.

Resolves #384
  • Loading branch information
Werner Hofstra authored Oct 14, 2022
1 parent ca88d0f commit 772d6b7
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 18 deletions.
1 change: 1 addition & 0 deletions crates/unftp-sbe-gcs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mime = "0.3.16"
percent-encoding = "2.2.0"
serde = { version = "1.0.145", features = ["derive"] }
serde_json = "1.0.85"
time = "0.3.15"
tokio = { version = "1.21.1", features = ["rt", "net", "sync", "io-util", "time", "fs"] }
tokio-stream = "0.1.10"
tokio-util = { version = "0.7.4", features = ["codec", "compat"] }
Expand Down
163 changes: 145 additions & 18 deletions crates/unftp-sbe-gcs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub mod object_metadata;
pub mod options;
mod response_body;
mod uri;
mod workflow_identity;
mod workload_identity;

pub use ext::ServerExt;

Expand All @@ -87,7 +87,9 @@ use response_body::{Item, ResponseBody};
use std::{
fmt::Debug,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::sync::RwLock;
use tokio_util::codec::{BytesCodec, FramedRead};
use uri::GcsUri;
use yup_oauth2::ServiceAccountAuthenticator;
Expand All @@ -99,6 +101,8 @@ pub struct CloudStorage {
uris: GcsUri,
client: Client<HttpsConnector<HttpConnector>>,
auth: AuthMethod,

cached_token: CachedToken,
}

impl CloudStorage {
Expand Down Expand Up @@ -134,30 +138,110 @@ impl CloudStorage {
{
let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> =
Client::builder().build(HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build());
CloudStorage {
Self {
client,
auth: auth.into(),
uris: GcsUri::new(base_url.into(), bucket.into(), root),

cached_token: Default::default(),
}
}

// TODO: Cache the token. For `ServiceAccountKey`, the oauth client would already cache the token - we just need to move it to `CloudStorage`. For `WorkloadIdentity`, we can cache it in `CloudStorage`.
#[tracing_attributes::instrument]
async fn get_token(&self) -> Result<String, Error> {
async fn get_token_value(&self) -> Result<String, Error> {
if let Some(token) = self.cached_token.get().await {
return Ok(token.value);
}

let token = self.fetch_token().await?;
self.cached_token.set(token.clone()).await;
Ok(token.value)
}

async fn fetch_token(&self) -> Result<Token, Error> {
match &self.auth {
AuthMethod::ServiceAccountKey(_) => {
let key = self.auth.to_service_account_key()?;
let auth = ServiceAccountAuthenticator::builder(key).hyper_client(self.client.clone()).build().await?;

auth.token(&["https://www.googleapis.com/auth/devstorage.read_write"])
.map_ok(|t| t.as_str().to_string())
.map_ok(|t| t.into())
.map_err(|e| Error::new(ErrorKind::PermanentFileNotAvailable, e))
.await
}
AuthMethod::WorkloadIdentity(service) => workflow_identity::request_token(service.clone(), self.client.clone())
.await
.map(|t| t.access_token),
AuthMethod::None => Ok("unftp_test".to_string()),
AuthMethod::WorkloadIdentity(service) => workload_identity::request_token(service.clone(), self.client.clone()).await.map(|t| t.into()),
AuthMethod::None => Ok(Token {
value: "unftp_test".to_string(),
expires_at: None,
}),
}
}
}

#[derive(Default, Clone, Debug)]
struct CachedToken {
inner: Arc<RwLock<Option<Token>>>,
}

impl CachedToken {
// get returns a token if it's available and not expired, and None otherwise.
async fn get(&self) -> Option<Token> {
let cache = self.inner.read().await;
cache.as_ref().and_then(|token| token.get_if_active())
}

async fn set(&self, token: Token) {
let mut cache = self.inner.write().await;
*cache = Some(token);
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
struct Token {
value: String,
expires_at: Option<time::OffsetDateTime>,
}

impl Token {
/// active yields true when the token is present and has not expired. In all other cases, it
/// returns false.
fn active(&self) -> bool {
self.expires_at
.map(|expires_at| {
let now = time::OffsetDateTime::now_utc();
const SAFETY_MARGIN: time::Duration = time::Duration::seconds(5);

expires_at > (now - SAFETY_MARGIN)
})
.unwrap_or(false)
}

fn get_if_active(&self) -> Option<Token> {
if self.active() {
Some(self.clone())
} else {
None
}
}
}

impl From<yup_oauth2::AccessToken> for Token {
fn from(source: yup_oauth2::AccessToken) -> Self {
Self {
value: source.as_str().to_string(),
expires_at: source.expiration_time(),
}
}
}

impl From<workload_identity::TokenResponse> for Token {
fn from(source: workload_identity::TokenResponse) -> Self {
let now = time::OffsetDateTime::now_utc();
let expires_in = time::Duration::seconds(source.expires_in.try_into().unwrap_or(0));

Self {
value: source.access_token,
expires_at: Some(now + expires_in),
}
}
}
Expand All @@ -176,7 +260,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {

let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand All @@ -203,7 +287,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {

let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand Down Expand Up @@ -231,7 +315,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {

let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;

let request: Request<Body> = Request::builder()
.uri(uri)
Expand Down Expand Up @@ -265,7 +349,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {
let uri: Uri = self.uris.get(path)?;
let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand Down Expand Up @@ -299,7 +383,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {

let reader = tokio::io::BufReader::with_capacity(4096, bytes);

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand All @@ -320,7 +404,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {
let uri: Uri = self.uris.delete(path)?;

let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();
let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand All @@ -338,7 +422,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {
let uri: Uri = self.uris.mkd(path)?;
let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand All @@ -364,7 +448,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {
let uri: Uri = self.uris.dir_empty(&path)?;
let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand Down Expand Up @@ -405,7 +489,7 @@ impl<User: UserDetail> StorageBackend<User> for CloudStorage {
let uri: Uri = self.uris.dir_empty(&path)?;
let client: Client<HttpsConnector<HttpConnector<GaiResolver>>, Body> = self.client.clone();

let token = self.get_token().await?;
let token = self.get_token_value().await?;
let request: Request<Body> = Request::builder()
.uri(uri)
.header(header::AUTHORIZATION, format!("Bearer {}", token))
Expand Down Expand Up @@ -451,3 +535,46 @@ fn result_based_on_http_status<T>(status: StatusCode, ok_val: T) -> Result<T, Er
}
Ok(ok_val)
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn cached_token() {
let cache: CachedToken = Default::default();

assert_eq!(cache.get().await, None);

cache
.set(Token {
value: "the_value".to_string(),
expires_at: None,
})
.await;
assert_eq!(cache.get().await, None);

cache
.set(Token {
value: "the_value".to_string(),
expires_at: Some(time::OffsetDateTime::now_utc() - time::Duration::seconds(10)),
})
.await;
assert_eq!(cache.get().await, None);

let in_future = Some(time::OffsetDateTime::now_utc() + time::Duration::seconds(10));
cache
.set(Token {
value: "the_value".to_string(),
expires_at: in_future.clone(),
})
.await;
assert_eq!(
cache.get().await,
Some(Token {
value: "the_value".to_string(),
expires_at: in_future,
}),
);
}
}

0 comments on commit 772d6b7

Please sign in to comment.