From d727503e4d52d40d7215a42398b02d58e964ed01 Mon Sep 17 00:00:00 2001 From: Robin Lin <128118209+RobinLin666@users.noreply.github.com> Date: Sat, 21 Sep 2024 18:15:26 +0800 Subject: [PATCH] object_score: Support Azure Fabric OAuth Provider (#6382) * Update Azure dependencies and add support for Fabric token authentication * Refactor Azure credential provider to support Fabric token authentication * Refactor Azure credential provider to remove unnecessary print statements and improve token handling * Bump object_store version to 0.11.0 * Refactor Azure credential provider to remove unnecessary print statements and improve token handling --- object_store/src/azure/builder.rs | 88 ++++++++++++++++++++- object_store/src/azure/credential.rs | 114 ++++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 3 deletions(-) diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs index 0208073e85c..35cedeafc04 100644 --- a/object_store/src/azure/builder.rs +++ b/object_store/src/azure/builder.rs @@ -17,8 +17,8 @@ use crate::azure::client::{AzureClient, AzureConfig}; use crate::azure::credential::{ - AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider, - WorkloadIdentityOAuthProvider, + AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider, + ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider, }; use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE}; use crate::client::TokenCredentialProvider; @@ -172,6 +172,14 @@ pub struct MicrosoftAzureBuilder { use_fabric_endpoint: ConfigValue, /// When set to true, skips tagging objects disable_tagging: ConfigValue, + /// Fabric token service url + fabric_token_service_url: Option, + /// Fabric workload host + fabric_workload_host: Option, + /// Fabric session token + fabric_session_token: Option, + /// Fabric cluster identifier + fabric_cluster_identifier: Option, } /// Configuration keys for [`MicrosoftAzureBuilder`] @@ -336,6 +344,34 @@ pub enum AzureConfigKey { /// - `disable_tagging` DisableTagging, + /// Fabric token service url + /// + /// Supported keys: + /// - `azure_fabric_token_service_url` + /// - `fabric_token_service_url` + FabricTokenServiceUrl, + + /// Fabric workload host + /// + /// Supported keys: + /// - `azure_fabric_workload_host` + /// - `fabric_workload_host` + FabricWorkloadHost, + + /// Fabric session token + /// + /// Supported keys: + /// - `azure_fabric_session_token` + /// - `fabric_session_token` + FabricSessionToken, + + /// Fabric cluster identifier + /// + /// Supported keys: + /// - `azure_fabric_cluster_identifier` + /// - `fabric_cluster_identifier` + FabricClusterIdentifier, + /// Client options Client(ClientConfigKey), } @@ -361,6 +397,10 @@ impl AsRef for AzureConfigKey { Self::SkipSignature => "azure_skip_signature", Self::ContainerName => "azure_container_name", Self::DisableTagging => "azure_disable_tagging", + Self::FabricTokenServiceUrl => "azure_fabric_token_service_url", + Self::FabricWorkloadHost => "azure_fabric_workload_host", + Self::FabricSessionToken => "azure_fabric_session_token", + Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier", Self::Client(key) => key.as_ref(), } } @@ -406,6 +446,14 @@ impl FromStr for AzureConfigKey { "azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), "azure_container_name" | "container_name" => Ok(Self::ContainerName), "azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging), + "azure_fabric_token_service_url" | "fabric_token_service_url" => { + Ok(Self::FabricTokenServiceUrl) + } + "azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost), + "azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken), + "azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => { + Ok(Self::FabricClusterIdentifier) + } // Backwards compatibility "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), _ => match s.strip_prefix("azure_").unwrap_or(s).parse() { @@ -525,6 +573,14 @@ impl MicrosoftAzureBuilder { } AzureConfigKey::ContainerName => self.container_name = Some(value.into()), AzureConfigKey::DisableTagging => self.disable_tagging.parse(value), + AzureConfigKey::FabricTokenServiceUrl => { + self.fabric_token_service_url = Some(value.into()) + } + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()), + AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()), + AzureConfigKey::FabricClusterIdentifier => { + self.fabric_cluster_identifier = Some(value.into()) + } }; self } @@ -561,6 +617,10 @@ impl MicrosoftAzureBuilder { AzureConfigKey::Client(key) => self.client_options.get_config_value(key), AzureConfigKey::ContainerName => self.container_name.clone(), AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()), + AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(), + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(), + AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(), + AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(), } } @@ -856,6 +916,30 @@ impl MicrosoftAzureBuilder { let credential = if let Some(credential) = self.credentials { credential + } else if let ( + Some(fabric_token_service_url), + Some(fabric_workload_host), + Some(fabric_session_token), + Some(fabric_cluster_identifier), + ) = ( + &self.fabric_token_service_url, + &self.fabric_workload_host, + &self.fabric_session_token, + &self.fabric_cluster_identifier, + ) { + // This case should precede the bearer token case because it is more specific and will utilize the bearer token. + let fabric_credential = FabricTokenOAuthProvider::new( + fabric_token_service_url, + fabric_workload_host, + fabric_session_token, + fabric_cluster_identifier, + self.bearer_token.clone(), + ); + Arc::new(TokenCredentialProvider::new( + fabric_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ } else if let Some(bearer_token) = self.bearer_token { static_creds(AzureCredential::BearerToken(bearer_token)) } else if let Some(access_key) = self.access_key { diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 7808c7c4a7c..6b5fa19d154 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -22,7 +22,7 @@ use crate::client::{CredentialProvider, TokenProvider}; use crate::util::hmac_sha256; use crate::RetryConfig; use async_trait::async_trait; -use base64::prelude::BASE64_STANDARD; +use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD}; use base64::Engine; use chrono::{DateTime, SecondsFormat, Utc}; use reqwest::header::{ @@ -51,10 +51,15 @@ pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-typ pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots"); pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source"); static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5"); +static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token"); +static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier"); +static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker"); +static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host"); pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; const CONTENT_TYPE_JSON: &str = "application/json"; const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER"; const MSI_API_VERSION: &str = "2019-08-01"; +const TOKEN_MIN_TTL: u64 = 300; /// OIDC scope used when interacting with OAuth2 APIs /// @@ -934,6 +939,113 @@ impl AzureCliCredential { } } +/// Encapsulates the logic to perform an OAuth token challenge for Fabric +#[derive(Debug)] +pub struct FabricTokenOAuthProvider { + fabric_token_service_url: String, + fabric_workload_host: String, + fabric_session_token: String, + fabric_cluster_identifier: String, + storage_access_token: Option, + token_expiry: Option, +} + +#[derive(Debug, Deserialize)] +struct Claims { + exp: u64, +} + +impl FabricTokenOAuthProvider { + /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store + pub fn new( + fabric_token_service_url: impl Into, + fabric_workload_host: impl Into, + fabric_session_token: impl Into, + fabric_cluster_identifier: impl Into, + storage_access_token: Option, + ) -> Self { + let (storage_access_token, token_expiry) = match storage_access_token { + Some(token) => match Self::validate_and_get_expiry(&token) { + Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => { + (Some(token), Some(expiry)) + } + _ => (None, None), + }, + None => (None, None), + }; + + Self { + fabric_token_service_url: fabric_token_service_url.into(), + fabric_workload_host: fabric_workload_host.into(), + fabric_session_token: fabric_session_token.into(), + fabric_cluster_identifier: fabric_cluster_identifier.into(), + storage_access_token, + token_expiry, + } + } + + fn validate_and_get_expiry(token: &str) -> Option { + let payload = token.split('.').nth(1)?; + let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?; + let decoded_str = str::from_utf8(&decoded_bytes).ok()?; + let claims: Claims = serde_json::from_str(decoded_str).ok()?; + Some(claims.exp) + } + + fn get_current_timestamp() -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()) + } +} + +#[async_trait::async_trait] +impl TokenProvider for FabricTokenOAuthProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> crate::Result>> { + if let Some(storage_access_token) = &self.storage_access_token { + if let Some(expiry) = self.token_expiry { + let exp_in = expiry - Self::get_current_timestamp(); + if exp_in > TOKEN_MIN_TTL { + return Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), + }); + } + } + } + + let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)]; + let access_token: String = client + .request(Method::GET, &self.fabric_token_service_url) + .header(&PARTNER_TOKEN, self.fabric_session_token.as_str()) + .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str()) + .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str()) + .header(&PROXY_HOST, self.fabric_workload_host.as_str()) + .query(&query_items) + .retryable(retry) + .idempotent(true) + .send() + .await + .context(TokenRequestSnafu)? + .text() + .await + .context(TokenResponseBodySnafu)?; + let exp_in = Self::validate_and_get_expiry(&access_token) + .map_or(3600, |expiry| expiry - Self::get_current_timestamp()); + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(access_token)), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), + }) + } +} + #[async_trait] impl CredentialProvider for AzureCliCredential { type Credential = AzureCredential;