Skip to content

Commit

Permalink
object_score: Support Azure Fabric OAuth Provider (#6382)
Browse files Browse the repository at this point in the history
* Update Azure dependencies and add support for Fabric token authentication

* Refactor Azure credential provider to support Fabric token authentication

* Refactor Azure credential provider to remove unnecessary print statements and improve token handling

* Bump object_store version to 0.11.0

* Refactor Azure credential provider to remove unnecessary print statements and improve token handling
  • Loading branch information
RobinLin666 authored Sep 21, 2024
1 parent bc6009f commit d727503
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 3 deletions.
88 changes: 86 additions & 2 deletions object_store/src/azure/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use crate::azure::client::{AzureClient, AzureConfig};
use crate::azure::credential::{
AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider,
WorkloadIdentityOAuthProvider,
AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider,
ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider,
};
use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE};
use crate::client::TokenCredentialProvider;
Expand Down Expand Up @@ -172,6 +172,14 @@ pub struct MicrosoftAzureBuilder {
use_fabric_endpoint: ConfigValue<bool>,
/// When set to true, skips tagging objects
disable_tagging: ConfigValue<bool>,
/// Fabric token service url
fabric_token_service_url: Option<String>,
/// Fabric workload host
fabric_workload_host: Option<String>,
/// Fabric session token
fabric_session_token: Option<String>,
/// Fabric cluster identifier
fabric_cluster_identifier: Option<String>,
}

/// Configuration keys for [`MicrosoftAzureBuilder`]
Expand Down Expand Up @@ -336,6 +344,34 @@ pub enum AzureConfigKey {
/// - `disable_tagging`
DisableTagging,

/// Fabric token service url
///
/// Supported keys:
/// - `azure_fabric_token_service_url`
/// - `fabric_token_service_url`
FabricTokenServiceUrl,

/// Fabric workload host
///
/// Supported keys:
/// - `azure_fabric_workload_host`
/// - `fabric_workload_host`
FabricWorkloadHost,

/// Fabric session token
///
/// Supported keys:
/// - `azure_fabric_session_token`
/// - `fabric_session_token`
FabricSessionToken,

/// Fabric cluster identifier
///
/// Supported keys:
/// - `azure_fabric_cluster_identifier`
/// - `fabric_cluster_identifier`
FabricClusterIdentifier,

/// Client options
Client(ClientConfigKey),
}
Expand All @@ -361,6 +397,10 @@ impl AsRef<str> for AzureConfigKey {
Self::SkipSignature => "azure_skip_signature",
Self::ContainerName => "azure_container_name",
Self::DisableTagging => "azure_disable_tagging",
Self::FabricTokenServiceUrl => "azure_fabric_token_service_url",
Self::FabricWorkloadHost => "azure_fabric_workload_host",
Self::FabricSessionToken => "azure_fabric_session_token",
Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier",
Self::Client(key) => key.as_ref(),
}
}
Expand Down Expand Up @@ -406,6 +446,14 @@ impl FromStr for AzureConfigKey {
"azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature),
"azure_container_name" | "container_name" => Ok(Self::ContainerName),
"azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging),
"azure_fabric_token_service_url" | "fabric_token_service_url" => {
Ok(Self::FabricTokenServiceUrl)
}
"azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost),
"azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken),
"azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => {
Ok(Self::FabricClusterIdentifier)
}
// Backwards compatibility
"azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)),
_ => match s.strip_prefix("azure_").unwrap_or(s).parse() {
Expand Down Expand Up @@ -525,6 +573,14 @@ impl MicrosoftAzureBuilder {
}
AzureConfigKey::ContainerName => self.container_name = Some(value.into()),
AzureConfigKey::DisableTagging => self.disable_tagging.parse(value),
AzureConfigKey::FabricTokenServiceUrl => {
self.fabric_token_service_url = Some(value.into())
}
AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()),
AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()),
AzureConfigKey::FabricClusterIdentifier => {
self.fabric_cluster_identifier = Some(value.into())
}
};
self
}
Expand Down Expand Up @@ -561,6 +617,10 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::Client(key) => self.client_options.get_config_value(key),
AzureConfigKey::ContainerName => self.container_name.clone(),
AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()),
AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(),
AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(),
AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(),
AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(),
}
}

Expand Down Expand Up @@ -856,6 +916,30 @@ impl MicrosoftAzureBuilder {

let credential = if let Some(credential) = self.credentials {
credential
} else if let (
Some(fabric_token_service_url),
Some(fabric_workload_host),
Some(fabric_session_token),
Some(fabric_cluster_identifier),
) = (
&self.fabric_token_service_url,
&self.fabric_workload_host,
&self.fabric_session_token,
&self.fabric_cluster_identifier,
) {
// This case should precede the bearer token case because it is more specific and will utilize the bearer token.
let fabric_credential = FabricTokenOAuthProvider::new(
fabric_token_service_url,
fabric_workload_host,
fabric_session_token,
fabric_cluster_identifier,
self.bearer_token.clone(),
);
Arc::new(TokenCredentialProvider::new(
fabric_credential,
self.client_options.client()?,
self.retry_config.clone(),
)) as _
} else if let Some(bearer_token) = self.bearer_token {
static_creds(AzureCredential::BearerToken(bearer_token))
} else if let Some(access_key) = self.access_key {
Expand Down
114 changes: 113 additions & 1 deletion object_store/src/azure/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::client::{CredentialProvider, TokenProvider};
use crate::util::hmac_sha256;
use crate::RetryConfig;
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
use base64::Engine;
use chrono::{DateTime, SecondsFormat, Utc};
use reqwest::header::{
Expand Down Expand Up @@ -51,10 +51,15 @@ pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-typ
pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
const CONTENT_TYPE_JSON: &str = "application/json";
const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
const MSI_API_VERSION: &str = "2019-08-01";
const TOKEN_MIN_TTL: u64 = 300;

/// OIDC scope used when interacting with OAuth2 APIs
///
Expand Down Expand Up @@ -934,6 +939,113 @@ impl AzureCliCredential {
}
}

/// Encapsulates the logic to perform an OAuth token challenge for Fabric
#[derive(Debug)]
pub struct FabricTokenOAuthProvider {
fabric_token_service_url: String,
fabric_workload_host: String,
fabric_session_token: String,
fabric_cluster_identifier: String,
storage_access_token: Option<String>,
token_expiry: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct Claims {
exp: u64,
}

impl FabricTokenOAuthProvider {
/// Create a new [`FabricTokenOAuthProvider`] for an azure backed store
pub fn new(
fabric_token_service_url: impl Into<String>,
fabric_workload_host: impl Into<String>,
fabric_session_token: impl Into<String>,
fabric_cluster_identifier: impl Into<String>,
storage_access_token: Option<String>,
) -> Self {
let (storage_access_token, token_expiry) = match storage_access_token {
Some(token) => match Self::validate_and_get_expiry(&token) {
Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
(Some(token), Some(expiry))
}
_ => (None, None),
},
None => (None, None),
};

Self {
fabric_token_service_url: fabric_token_service_url.into(),
fabric_workload_host: fabric_workload_host.into(),
fabric_session_token: fabric_session_token.into(),
fabric_cluster_identifier: fabric_cluster_identifier.into(),
storage_access_token,
token_expiry,
}
}

fn validate_and_get_expiry(token: &str) -> Option<u64> {
let payload = token.split('.').nth(1)?;
let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
let claims: Claims = serde_json::from_str(decoded_str).ok()?;
Some(claims.exp)
}

fn get_current_timestamp() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs())
}
}

#[async_trait::async_trait]
impl TokenProvider for FabricTokenOAuthProvider {
type Credential = AzureCredential;

/// Fetch a token
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
if let Some(storage_access_token) = &self.storage_access_token {
if let Some(expiry) = self.token_expiry {
let exp_in = expiry - Self::get_current_timestamp();
if exp_in > TOKEN_MIN_TTL {
return Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
});
}
}
}

let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
let access_token: String = client
.request(Method::GET, &self.fabric_token_service_url)
.header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
.header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
.header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
.header(&PROXY_HOST, self.fabric_workload_host.as_str())
.query(&query_items)
.retryable(retry)
.idempotent(true)
.send()
.await
.context(TokenRequestSnafu)?
.text()
.await
.context(TokenResponseBodySnafu)?;
let exp_in = Self::validate_and_get_expiry(&access_token)
.map_or(3600, |expiry| expiry - Self::get_current_timestamp());
Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(access_token)),
expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
})
}
}

#[async_trait]
impl CredentialProvider for AzureCliCredential {
type Credential = AzureCredential;
Expand Down

0 comments on commit d727503

Please sign in to comment.